Fine-Tuning ResNet50 for Alzheimer's MRI Classification
This repository contains a Jupyter Notebook for fine-tuning a ResNet50 model to classify Alzheimer's disease stages from MRI images. The notebook uses PyTorch and the dataset is loaded from the Hugging Face Datasets library.
Table of Contents
- Introduction
- Dataset
- Model Architecture
- Setup
- Training
- Evaluation
- Usage
- Results
- Contributing
- License
Introduction
This notebook fine-tunes a pre-trained ResNet50 model to classify MRI images into one of four classes:
- Mild Demented
- Moderate Demented
- Non-Demented
- Very Mild Demented
Dataset
The dataset is Falah/Alzheimer_MRI, hosted on the Hugging Face Hub and loaded with the Datasets library. It consists of MRI images, each labeled with one of the four classes listed above.
Model Architecture
The model architecture is based on ResNet50. The final fully connected layer is modified to output predictions for 4 classes.
Setup
To run the notebook locally, follow these steps:
- Clone the repository:
git clone https://github.com/your_username/alzheimer_mri_classification.git
cd alzheimer_mri_classification
- Install the required dependencies:
pip install -r requirements.txt
- Open the notebook:
jupyter notebook fine-tuning.ipynb
Training
The notebook includes sections for:
- Loading and preprocessing the dataset
- Defining the model architecture
- Setting up the training loop with a learning rate scheduler and optimizer
- Training the model for a specified number of epochs
- Saving the trained model weights
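The steps above can be sketched as a compact training routine. The notebook's actual optimizer, scheduler, and hyperparameters are not stated in this README, so Adam, StepLR, and every default value below are assumptions chosen for illustration; `train_one_epoch` and `fit` are hypothetical helper names.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over `loader` and return the mean training loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so the mean is per sample.
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples


def fit(model, loader, epochs=5, lr=1e-4, save_path=None, device="cpu"):
    """Train for `epochs` epochs and optionally save the state dict."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(  # assumed schedule
        optimizer, step_size=3, gamma=0.1
    )
    history = []
    for _ in range(epochs):
        history.append(train_one_epoch(model, loader, criterion, optimizer, device))
        scheduler.step()  # once per epoch, after the optimizer updates
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return history
```

Passing `save_path="alzheimer_model_resnet50.pth"` would write the weights under the file name used in the Usage section.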
Evaluation
The notebook includes a section for evaluating the trained model on the validation set. It calculates and prints the validation loss and accuracy.
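A minimal sketch of such an evaluation pass, assuming a data loader that yields `(images, labels)` batches; the `evaluate` helper is hypothetical and may differ from the notebook's code.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return (mean loss, accuracy) of `model` over all samples in `loader`."""
    model.to(device).eval()
    # Sum per-sample losses so the final mean is exact for uneven batch sizes.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, n_samples = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        n_samples += labels.size(0)
    return total_loss / n_samples, correct / n_samples
```

Accuracy is returned as a fraction between 0 and 1; multiply by 100 to report it as a percentage.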
Usage
After training, the notebook saves the model weights to alzheimer_model_resnet50.pth, which can then be loaded for inference on new MRI images.
Load the model architecture and weights:
```python
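import torch
import torch.nn as nn
from torchvision import models

model = models.resnet50(weights=None)  # architecture only; weights come from the checkpoint
model.fc = nn.Linear(model.fc.in_features, 4)  # same 4-class head used during training
model.load_state_dict(torch.load("alzheimer_model_resnet50.pth", map_location=torch.device('cpu')))
model.eval()  # inference mode: fixes batch-norm statistics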
```
Results
The model achieved an accuracy of 95.9375% on the validation set.
Contributing
Contributions are welcome! If you have any suggestions, bug reports, or feature requests, please open an issue or submit a pull request.