---
license: apache-2.0
---
# Classification with Neural Decision Forests
This is an example notebook for the Keras sprint prepared by Hugging Face. The Keras sprint aims to reproduce Keras examples and build interactive demos for them. The markdown sections beginning with 🤗, together with the code snippets that follow them, were added by the Hugging Face team to show you how to host your model and build a demo.
Original Author of the Neural Decision Forests Example: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
## Introduction
This example provides an implementation of the [Deep Neural Decision Forest](https://ieeexplore.ieee.org/document/7410529) model introduced by P. Kontschieder et al. for structured data classification. It demonstrates how to build a stochastic and differentiable decision tree model, train it end-to-end, and unify decision trees with deep representation learning.
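
The model is trained on the United States Census Income dataset from the UC Irvine Machine Learning Repository, where the task is to predict whether a person earns over 50K a year. The dataset has the following features:

| Numerical Features | Categorical Features |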
| :-- | :-- |
| age | workclass |
| education-num | education |
| capital-gain | marital-status |
| capital-loss | occupation |
| hours-per-week | relationship |
| | race |
| | gender |
| | native-country |

Dropped feature: `fnlwgt`

Target label: `income_bracket`

The dataset comes in two splits, one for training and one for testing: the training split has 32,561 samples and the test split has 16,282 samples.
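
The sketch below shows one way to turn raw lines of the dataset into features and a label index without TensorFlow: it drops `fnlwgt` and maps `income_bracket` to an index. It assumes the raw files follow the UCI column order with no header row, that test-split labels may carry a trailing `.`, and it reuses this card's feature names as keys. `parse_rows`, `CSV_HEADER`, and `TARGET_LABELS` are hypothetical names; the original example builds a `tf.data.Dataset` instead.

```python
import csv
import io

# Assumed column order of the raw files (hypothetical constant; the raw
# files have no header row, so the card's feature names are reused here).
CSV_HEADER = [
    "age", "workclass", "fnlwgt", "education", "education-num",
    "marital-status", "occupation", "relationship", "race", "gender",
    "capital-gain", "capital-loss", "hours-per-week", "native-country",
    "income_bracket",
]
DROPPED = {"fnlwgt"}
TARGET = "income_bracket"
TARGET_LABELS = ["<=50K", ">50K"]  # label string -> index 0 or 1
NUMERIC = {"age", "education-num", "capital-gain", "capital-loss", "hours-per-week"}


def parse_rows(text):
    """Yield (features, label_index) pairs from raw comma-separated text."""
    reader = csv.reader(io.StringIO(text), skipinitialspace=True)
    for row in reader:
        if len(row) != len(CSV_HEADER):
            continue  # skip blank lines and lines that are not data records
        record = dict(zip(CSV_HEADER, row))
        # Assumption: test labels may end with ".", e.g. ">50K."
        label = record.pop(TARGET).rstrip(".")
        features = {
            name: float(value) if name in NUMERIC else value
            for name, value in record.items()
            if name not in DROPPED
        }
        yield features, TARGET_LABELS.index(label)
```

Each yielded pair holds the 13 remaining features and an integer label, which is the shape of record the training pipeline consumes.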
## Training procedure
1. **Prepare Data:** Create `tf.data.Dataset` objects for training and validation.
   We create an input function that reads and parses the data files and converts the features and labels into a `tf.data.Dataset` for training and validation. The input is preprocessed by mapping the target label to an index, and `layers.StringLookup` is used to prepare the categorical data.
2. **Encode Features:** We encode the categorical and numerical features as follows:
   - **Categorical features:** A `StringLookup` layer converts string values to integer indices. Since we use no mask token and expect no out-of-vocabulary (OOV) tokens, we set `mask_token=None` and `num_oov_indices=0`. An embedding layer with the specified dimensions then maps each index to a dense vector.
   - **Numerical features:** These are used as they are; `tf.expand_dims` only adds a trailing axis so that each one can be concatenated with the embeddings into a single input vector.
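3. **Create Model:** Build the neural decision tree described below on top of the encoded features. The forest model trains several such trees together and averages their outputs.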
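
The feature encoding from step 2 can be sketched in NumPy as follows. A plain dictionary lookup stands in for `StringLookup` with `mask_token=None` and `num_oov_indices=0` (an unseen value raises `KeyError` here, much as the Keras layer raises an error on OOV input), and a random matrix stands in for a trainable embedding layer. The class name `FeatureEncoder`, the sample vocabularies, and the embedding size rule `int(sqrt(vocabulary_size))` are assumptions for illustration.

```python
import math

import numpy as np


class FeatureEncoder:
    """Hypothetical NumPy stand-in for the encoding step: embeddings for
    categorical features, raw values for numerical ones, concatenated."""

    def __init__(self, vocabularies, numeric_names, seed=0):
        rng = np.random.default_rng(seed)
        self.numeric_names = list(numeric_names)
        self.lookups = {}
        self.embeddings = {}
        for name, vocab in vocabularies.items():
            # String -> integer index, with no mask token and no OOV bucket.
            self.lookups[name] = {token: i for i, token in enumerate(vocab)}
            # Assumed sizing rule: int(sqrt(vocabulary size)) dimensions.
            dims = int(math.sqrt(len(vocab)))
            # Random values stand in for a trainable embedding table.
            self.embeddings[name] = rng.normal(scale=0.05, size=(len(vocab), dims))

    def __call__(self, batch):
        parts = []
        for name, lookup in self.lookups.items():
            # An unseen value raises KeyError (there is no out-of-vocabulary index).
            indices = np.array([lookup[value] for value in batch[name]])
            parts.append(self.embeddings[name][indices])  # (batch, dims)
        for name in self.numeric_names:
            values = np.asarray(batch[name], dtype=np.float64)
            parts.append(np.expand_dims(values, -1))  # (batch, 1), values unchanged
        return np.concatenate(parts, axis=-1)  # one vector per instance


vocabularies = {
    "workclass": ["Private", "State-gov", "Self-emp-not-inc", "Federal-gov"],
    "gender": ["Male", "Female"],
}
encoder = FeatureEncoder(vocabularies, ["age", "hours-per-week"])
batch = {
    "workclass": ["Private", "State-gov"],
    "gender": ["Female", "Male"],
    "age": [39, 50],
    "hours-per-week": [40, 13],
}
x = encoder(batch)
print(x.shape)  # (2, 5): 2 + 1 embedding columns, then 2 numerical columns
```

Each categorical feature contributes a block of embedding columns and each numerical feature a single column, so the output is the single per-instance vector that the tree model expects as input.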
### Deep Neural Decision Tree
A neural decision tree model has two sets of weights to learn. The first is `pi`, which represents the probability distribution of the classes in the tree leaves. The second is the weights of the routing layer `decision_fn`, which produces the stochastic routing decisions that determine the probability of reaching each leaf. The forward pass of the model works as follows:
- The model expects input features as a single vector encoding all the features of an instance in the batch. This vector can be generated by a Convolutional Neural Network (CNN) applied to images or by dense transformations applied to structured data features.
- The model first applies `used_features_mask` to randomly select a subset of input features to use.
- Then, the model computes the probabilities (`mu`) of the input instances reaching the tree leaves by iteratively performing stochastic routing through the tree levels.
- Finally, the probabilities of reaching the leaves are combined with the class probabilities at the leaves to produce the final outputs.
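
The forward pass above can be sketched in NumPy. Here `decision_fn` is assumed to be a dense layer with a sigmoid activation whose unit `i` gives the probability of taking the first branch at internal node `i`, with nodes numbered breadth-first from 1 (unit 0 is unused); the class name, initializations, and seed are illustrative assumptions, not the trained Keras model. Because every split sends probability `d` one way and `1 - d` the other, each row of `mu` sums to 1, so the outputs are valid class distributions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


class NeuralDecisionTreeSketch:
    """Hypothetical NumPy forward pass of a soft decision tree of a given depth."""

    def __init__(self, depth, num_features, used_features_rate, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.depth = depth
        self.num_leaves = 2 ** depth
        # used_features_mask: one-hot rows that select a random feature subset.
        num_used = int(num_features * used_features_rate)
        chosen = rng.choice(num_features, size=num_used, replace=False)
        self.used_features_mask = np.eye(num_features)[chosen]  # (num_used, num_features)
        # pi: unnormalized class scores for each leaf (softmax applied in the forward pass).
        self.pi = rng.normal(size=(self.num_leaves, num_classes))
        # decision_fn weights: one sigmoid unit per leaf index; units 1..num_leaves-1
        # are the internal nodes in breadth-first order.
        self.w = rng.normal(scale=0.1, size=(num_used, self.num_leaves))
        self.b = np.zeros(self.num_leaves)

    def __call__(self, x):
        batch = x.shape[0]
        x = x @ self.used_features_mask.T                  # (batch, num_used)
        d = sigmoid(x @ self.w + self.b)[:, :, None]        # (batch, num_leaves, 1)
        decisions = np.concatenate([d, 1.0 - d], axis=2)    # (batch, num_leaves, 2)
        mu = np.ones((batch, 1, 1))                         # root is reached with prob. 1
        begin, end = 1, 2
        for level in range(self.depth):
            mu = mu.reshape(batch, -1, 1)                   # (batch, 2**level, 1)
            mu = np.tile(mu, (1, 1, 2))                     # (batch, 2**level, 2)
            mu = mu * decisions[:, begin:end, :]            # split mass between children
            begin, end = end, end + 2 ** (level + 1)
        mu = mu.reshape(batch, self.num_leaves)             # probability of reaching each leaf
        return mu @ softmax(self.pi, axis=1)                # (batch, num_classes)


tree = NeuralDecisionTreeSketch(depth=3, num_features=6, used_features_rate=0.5, num_classes=2)
x = np.random.default_rng(1).normal(size=(4, 6))
print(tree(x).sum(axis=1))  # each row sums to 1 up to floating-point rounding
```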
4. **Compile, Train and Evaluate Model:**
   - The loss function is `SparseCategoricalCrossentropy`.
   - The metric used to evaluate the model's performance is `SparseCategoricalAccuracy`.
   - The optimizer is `Adam` with a learning rate of 0.01.
   - The batch size is 265, and the model was trained for 5 epochs.
   - Finally, both the decision tree model and the forest model were evaluated on the test dataset, each reaching an accuracy of about 85%.
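
The loss, the metric, and the way a forest combines its trees can be sketched in NumPy. The functions below are hypothetical stand-ins for the Keras `SparseCategoricalCrossentropy` loss and `SparseCategoricalAccuracy` metric applied to probability outputs; averaging the trees' class probabilities follows how the forest model in the Keras example combines its trees. The toy labels and predictions are made up for illustration.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, probs, eps=1e-7):
    """Mean negative log-probability assigned to each instance's true class index."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(-np.mean(np.log(probs[np.arange(len(y_true)), y_true])))


def sparse_categorical_accuracy(y_true, probs):
    """Fraction of instances whose highest-probability class is the true index."""
    return float(np.mean(np.argmax(probs, axis=1) == y_true))


def forest_outputs(tree_outputs):
    """Average the class probabilities predicted by each tree in the forest."""
    return np.mean(np.stack(tree_outputs), axis=0)


# Toy example with two trees and four instances (illustrative values only).
y = np.array([0, 1, 1, 0])
tree_a = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]])
tree_b = np.array([[0.7, 0.3], [0.4, 0.6], [0.2, 0.8], [0.5, 0.5]])
forest = forest_outputs([tree_a, tree_b])
print(sparse_categorical_accuracy(y, tree_a))  # 0.75
print(sparse_categorical_accuracy(y, forest))  # 1.0
print(sparse_categorical_crossentropy(y, forest))
```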
### Training hyperparameters
The following hyperparameters were used during training:

| Hyperparameter | Value |
| :-- | :-- |
| optimizer | Adam |
| learning-rate | 0.01 |
| batch-size | 265 |
| num-epochs | 5 |
| num-trees | 10 |
| depth | 10 |
| used-features-rate | 1.0 |
| num-classes | 2 |
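
The table's values can also be written as a configuration, together with a few quantities they imply. The dictionary keys, the helper `derived_quantities`, and its `num_features` argument (the width of the encoded input vector, which depends on the embedding sizes; 40 below is a hypothetical value) are illustrative assumptions; the step count assumes the final partial batch is kept.

```python
import math

# Values from the hyperparameter table (hypothetical key names).
HPARAMS = {
    "optimizer": "Adam",
    "learning_rate": 0.01,
    "batch_size": 265,
    "num_epochs": 5,
    "num_trees": 10,
    "depth": 10,
    "used_features_rate": 1.0,
    "num_classes": 2,
}
TRAIN_SAMPLES = 32561  # size of the training split


def derived_quantities(h, num_features, train_samples=TRAIN_SAMPLES):
    """Sizes implied by the hyperparameters for a single tree and one epoch."""
    num_leaves = 2 ** h["depth"]
    return {
        "num_leaves": num_leaves,                 # leaves per tree
        "num_internal_nodes": num_leaves - 1,     # routing decisions per tree
        "num_used_features": int(num_features * h["used_features_rate"]),
        "pi_params_per_tree": num_leaves * h["num_classes"],
        # Assumes the last, partial batch is kept rather than dropped.
        "steps_per_epoch": math.ceil(train_samples / h["batch_size"]),
    }


derived = derived_quantities(HPARAMS, num_features=40)  # 40 is a hypothetical width
print(derived["steps_per_epoch"])  # 123
```

With `depth` set to 10, each tree has 2^10 = 1,024 leaves and 1,023 routing decisions, and one epoch over the 32,561 training samples takes 123 steps of 265 samples, the last one partial.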
## Credits
- HF Contribution: [Tarun R Jain](https://twitter.com/TRJ_0751)