---
language: en
tags:
- mnist
- cnn
- pytorch
- image-classification
license: apache-2.0
datasets:
- mnist
metrics:
- accuracy
---

# MNIST Digit Classifier with Noise Reduction

**Model type**: Convolutional Neural Network (CNN)

**Model Architecture**: 3 Convolutional Layers, 1 Adaptive Pooling Layer, 1 Fully Connected Layer

## Model Description

This model is a Convolutional Neural Network (CNN) trained on the MNIST dataset and designed to classify handwritten digits from 0 to 9. The model uses data augmentation to improve its robustness, especially for noisy or rotated images. The preprocessing step includes Gaussian blur for noise reduction, making the model more resilient to outliers and noisy digit inputs.

### Model Architecture:
- **Convolutional Layers**: 3 convolutional layers with ReLU activations.
- **Adaptive Pooling**: Adaptive Average Pooling to ensure the model handles dynamic input sizes.
- **Fully Connected Layer**: The output from the convolutional layers is flattened and passed through a fully connected layer to predict the digit.
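
The card does not give the exact layer sizes. Purely as an illustration, a PyTorch module matching the description above (3 convolutional layers with ReLU, adaptive average pooling, a single fully connected layer) could look like the sketch below; the channel counts and the `ImageClassifier` name are assumptions chosen to line up with the usage snippet further down, and the published `mnist_classifier.pth` checkpoint will only load into the author's actual class definition.

```python
import torch
import torch.nn as nn

class ImageClassifier(nn.Module):
    """Sketch of the described CNN: 3 conv + ReLU blocks, adaptive average pooling, 1 FC layer.
    The 32/64/128 channel counts are assumptions; the model card does not specify them."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # fixed output size regardless of input resolution
        self.fc = nn.Linear(128, num_classes)     # single fully connected classification layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))
```
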

## Training Data

The model was trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which consists of 60,000 training images and 10,000 test images of handwritten digits. The images are 28x28 pixels in grayscale.

### Data Preprocessing:
- **Data Augmentation**: Random rotations (up to 10 degrees) and random translations (up to 10%) were applied during training to make the model more robust to variations.
- **Normalization**: The pixel values were normalized to the range [-1, 1].
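
A torchvision pipeline consistent with this description might look like the sketch below. Only the rotation/translation limits and the [-1, 1] normalization are stated in this card; the blur kernel size and the exact ordering of the transforms are assumptions.

```python
from torchvision import transforms

# Training-time augmentation and normalization (illustrative sketch, not the author's exact code)
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),  # rotations up to 10 degrees, shifts up to 10%
    transforms.GaussianBlur(kernel_size=3),                     # blur for noise reduction (kernel size assumed)
    transforms.ToTensor(),                                      # [0, 255] -> [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,)),                       # [0.0, 1.0] -> [-1.0, 1.0]
])
```
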
## Intended Use

This model is designed for:
- Recognizing handwritten digits in applications like form scanning, document analysis, or real-time digit detection.
- Educational purposes to demonstrate CNN-based image classification.

### How to use the model

You can load the trained model in PyTorch and use it to classify digit images as shown below:

```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model
# (ImageClassifier is the CNN class for this model; its definition must be available in scope.)
model = ImageClassifier()
model.load_state_dict(torch.load('mnist_classifier.pth'))
model.eval()

# Preprocess an input image
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # same [-1, 1] normalization as in training
])

img = Image.open('path_to_image').convert('L')
img_tensor = transform(img).unsqueeze(0)

# Perform inference
with torch.no_grad():
    output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()

print(f"Predicted Label: {predicted_label}")
```

## Evaluation Results

### Metrics:
The model was evaluated on the MNIST test set, achieving the following results:
- **Accuracy**: ~98%
- **Loss**: Cross-entropy loss decreased to approximately 0.15 after 10 epochs.
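
For reproducibility, a standard evaluation loop over the MNIST test split is sketched below. This is a generic sketch rather than the author's evaluation script; it reuses the illustrative `ImageClassifier` class and the [-1, 1] normalization from earlier in this card.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # same normalization as training
])
test_set = datasets.MNIST(root="data", train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=32)

model = ImageClassifier()
model.load_state_dict(torch.load("mnist_classifier.pth"))
model.eval()

correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()

print(f"Test accuracy: {correct / len(test_set):.4f}")
```
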

### Performance on Noisy Inputs:
The model was tested on noisy images (e.g., images with added noise or distortions), and the preprocessing steps (Gaussian blur, resizing) helped improve the model's performance on such inputs.
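
A quick way to probe this behaviour is to corrupt a clean test digit with synthetic noise and compare predictions with and without a blur step, as in the sketch below (an illustrative check, not the author's test protocol; it reuses `model` and `test_set` from the evaluation sketch above, and the noise level is arbitrary).

```python
import torch
from torchvision import transforms

img_tensor = test_set[0][0].unsqueeze(0)                   # one normalized test digit, shape (1, 1, 28, 28)
noisy = (img_tensor + 0.3 * torch.randn_like(img_tensor)).clamp(-1.0, 1.0)
blurred = transforms.GaussianBlur(kernel_size=3)(noisy)    # simple denoising step

with torch.no_grad():
    print("noisy prediction:", model(noisy).argmax(dim=1).item())
    print("blurred prediction:", model(blurred).argmax(dim=1).item())
```
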
## Limitations

- **Noisy Inputs**: Although preprocessing helps, very noisy or distorted inputs might still be challenging for the model.
- **Generalization**: This model is primarily trained on MNIST digits. It may not generalize well to digits from different writing styles or other number systems.

## Training Details

- **Loss Function**: Cross-entropy Loss
- **Batch Size**: 32
- **Epochs**: 10
- **Data Augmentation**: Random rotations and translations
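
A minimal training loop consistent with these settings is sketched below. The batch size, epoch count, loss function, and augmentation follow the list above; the optimizer and learning rate are not stated in this card, so Adam with lr=1e-3 is an assumption, as is the `ImageClassifier` class from the architecture sketch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Augmentation + normalization as described under "Data Preprocessing"
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set = datasets.MNIST(root="data", train=True, download=True, transform=train_transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

model = ImageClassifier()                                   # illustrative class from the sketch above
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # optimizer and lr assumed, not stated in the card

for epoch in range(10):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}: last batch loss {loss.item():.4f}")

torch.save(model.state_dict(), "mnist_classifier.pth")
```
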
## Ethical Considerations

There are no significant ethical concerns related to this model. However, users should be aware that the model is specifically trained on simple MNIST digits and may not perform well in more complex scenarios.

## Contact

For any questions or feedback, please reach out to the model author via [git@here.news].