---
license: cc-by-nc-sa-4.0
---

# UNI-based ABMIL models for metastasis detection

These are weakly-supervised, attention-based multiple instance learning models for binary metastasis detection (normal versus metastasis). The models were trained on the [CAMELYON16](https://camelyon16.grand-challenge.org/Data/) dataset using UNI embeddings.

If you find this model useful, please cite our corresponding [preprint](https://arxiv.org/abs/2409.03080):

```bibtex
@misc{kaczmarzyk2024explainableaicomputationalpathology,
      title={Explainable AI for computational pathology identifies model limitations and tissue biomarkers},
      author={Jakub R. Kaczmarzyk and Joel H. Saltz and Peter K. Koo},
      year={2024},
      eprint={2409.03080},
      archivePrefix={arXiv},
      primaryClass={q-bio.TO},
      url={https://arxiv.org/abs/2409.03080},
}
```

# Data

- Training set consisted of 243 whole slide images (WSIs).
  - 143 negative
  - 100 positive
    - 52 macrometastases
    - 48 micrometastases
- Validation set consisted of 27 WSIs.
  - 16 negative
  - 11 positive
    - 6 macrometastases
    - 5 micrometastases
- Test set consisted of 129 WSIs.
  - 80 negative
  - 49 positive
    - 22 macrometastases
    - 27 micrometastases

# Evaluation

Below are the classification results on the test set.

|   Seed |   Sensitivity |   Specificity |    BA |   Precision |    F1 |
|-------:|--------------:|--------------:|------:|------------:|------:|
|      0 |         0.959 |         1.000 | 0.980 |       1.000 | 0.979 |
|      1 |         0.959 |         0.988 | 0.973 |       0.979 | 0.969 |
|      2 |         1.000 |         1.000 | 1.000 |       1.000 | 1.000 |
|      3 |         0.980 |         0.950 | 0.965 |       0.923 | 0.950 |
|      4 |         0.980 |         1.000 | 0.990 |       1.000 | 0.990 |

# How to reuse the model

The model expects 128 x 128 micrometer patches, embedded with the UNI model.
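
To make the physical patch size concrete, the sketch below converts 128 micrometers into a pixel width for a given slide resolution. The function name `patch_size_px` and the example resolutions (0.25 and 0.5 micrometers per pixel) are illustrative assumptions, not values taken from the training pipeline; read the actual resolution from your slide's metadata.

```python
def patch_size_px(patch_um: float, mpp: float) -> int:
    """Return the patch width in pixels for a physical width in micrometers.

    mpp is the slide resolution in micrometers per pixel.
    """
    if mpp <= 0:
        raise ValueError("mpp must be positive")
    return round(patch_um / mpp)


# Illustrative resolutions only (hypothetical values, check your slide metadata).
for mpp in (0.25, 0.5):
    print(f"{mpp} um/px -> {patch_size_px(128, mpp)} px")
```

Under these assumed resolutions, a 128 micrometer patch spans 512 pixels at 0.25 micrometers per pixel and 256 pixels at 0.5 micrometers per pixel.
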

```python
import torch
from abmil import AttentionMILModel

model = AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2, gated_attention=True)
model.eval()
state_dict = torch.load("seed2/model_best.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Load a bag of features (placeholder: 1000 patches x 1024-dim embeddings)
bag = torch.ones(1000, 1024)
with torch.inference_mode():
    logits, attention = model(bag)
```

# How to train the model

Download the UNI embeddings for CAMELYON16 from https://huggingface.co/datasets/kaczmarj/camelyon16-uni and then run the commands below.

```shell
# Seed 0
python train_classification.py --model-name AttentionMILModel \
    --features-dir path/to/features/ \
    --output-dir outputs/abmil-uni-128um_seed0 \
    --csv data.csv --label-col binary_label_int --num-classes 2 \
    --embedding-size 1024 --split-json splits.json --fold 0 \
    --num-epochs 20 --seed 0 -L 512 -D 384 --lr 1e-4

# Seed 1
python train_classification.py --model-name AttentionMILModel \
    --features-dir path/to/features/ \
    --output-dir outputs/abmil-uni-128um_seed1 \
    --csv data.csv --label-col binary_label_int --num-classes 2 \
    --embedding-size 1024 --split-json splits.json --fold 0 \
    --num-epochs 20 --seed 1 -L 512 -D 384 --lr 1e-4

# Seed 2
python train_classification.py --model-name AttentionMILModel \
    --features-dir path/to/features/ \
    --output-dir outputs/abmil-uni-128um_seed2 \
    --csv data.csv --label-col binary_label_int --num-classes 2 \
    --embedding-size 1024 --split-json splits.json --fold 0 \
    --num-epochs 20 --seed 2 -L 512 -D 384 --lr 1e-4

# Seed 3
python train_classification.py --model-name AttentionMILModel \
    --features-dir path/to/features/ \
    --output-dir outputs/abmil-uni-128um_seed3 \
    --csv data.csv --label-col binary_label_int --num-classes 2 \
    --embedding-size 1024 --split-json splits.json --fold 0 \
    --num-epochs 20 --seed 3 -L 512 -D 384 --lr 1e-4

# Seed 4
python train_classification.py --model-name AttentionMILModel \
    --features-dir path/to/features/ \
    --output-dir outputs/abmil-uni-128um_seed4 \
    --csv data.csv --label-col binary_label_int --num-classes 2 \
    --embedding-size 1024 --split-json splits.json --fold 0 \
    --num-epochs 20 --seed 4 -L 512 -D 384 --lr 1e-4
```
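
The five training commands differ only in `--seed` and in the suffix of `--output-dir`. As a convenience, the following sketch builds the same command lines in a loop and prints them. The helper name `build_command` is hypothetical; every flag and value is copied from the commands above.

```python
import shlex


def build_command(seed: int) -> list[str]:
    """Build the training command for one seed (flags copied from the commands above)."""
    return [
        "python", "train_classification.py",
        "--model-name", "AttentionMILModel",
        "--features-dir", "path/to/features/",
        "--output-dir", f"outputs/abmil-uni-128um_seed{seed}",
        "--csv", "data.csv",
        "--label-col", "binary_label_int",
        "--num-classes", "2",
        "--embedding-size", "1024",
        "--split-json", "splits.json",
        "--fold", "0",
        "--num-epochs", "20",
        "--seed", str(seed),
        "-L", "512",
        "-D", "384",
        "--lr", "1e-4",
    ]


# One command per seed, matching seeds 0 through 4 above.
commands = [build_command(seed) for seed in range(5)]
for cmd in commands:
    print(shlex.join(cmd))
```

Replacing the `print` call with `subprocess.run(cmd, check=True)` should launch the runs one after another, assuming `train_classification.py` is in the working directory and the placeholder paths have been filled in.
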