## Model description

This model is built from two architectural components proposed by Bryan Lim et al. in [Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363): the Gated Residual Network (GRN) and the Variable Selection Network (VSN), both of which are very useful for structured data classification tasks.

1. Gated Residual Networks (GRN) consist of skip connections and gating layers that enable efficient information flow. They have the flexibility to apply non-linear processing only where it is needed.

2. Variable Selection Networks (VSN) select the most important features from the input, suppressing unnecessary noisy inputs that could harm the model's performance.

**Note:** This model does not implement the whole TFT architecture; it uses only the GRN and VSN components described in the paper, demonstrating that GRNs and VSNs on their own can be very useful for structured data learning tasks.
## Intended uses

This model can be used for a binary classification task: determining whether a person makes over $50K a year.
## Training and evaluation data
This model was trained using the [United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29) provided by the UCI Machine Learning Repository.

The dataset contains weighted census data extracted from the 1994 and 1995 Current Population Surveys conducted by the U.S. Census Bureau.

The dataset comprises ~300K samples with 41 input features: 7 numerical features and 34 categorical features:

| Numerical Features | Categorical Features |
| :-- | :-- |
| age | class of worker |
| wage per hour | industry code |
| capital gains | occupation code |
| capital losses | adjusted gross income |
| dividends from stocks | education |
| num persons worked for employer | veterans benefits |
| weeks worked in year | enrolled in edu inst last wk |
| | marital status |
| | major industry code |
| | major occupation code |
| | race |
| | hispanic origin |
| | sex |
| | member of a labor union |
| | reason for unemployment |
| | full or part time employment stat |
| | federal income tax liability |
| | tax filer status |
| | region of previous residence |
| | state of previous residence |
| | detailed household and family stat |
| | detailed household summary in household |
| | instance weight |
| | migration code-change in msa |
| | migration code-change in reg |
| | migration code-move within reg |
| | live in this house 1 year ago |
| | migration prev res in sunbelt |
| | family members under 18 |
| | total person earnings |
| | country of birth father |
| | country of birth mother |
| | country of birth self |
| | citizenship |
| | total person income |
| | own business or self employed |
| | taxable income amount |
| | fill inc questionnaire for veteran's admin |

## Training procedure

0. **Prepare Data:** Download the data, convert the target column *income_level* from string to integer, and split the data into training and validation sets.

1. **Prepare tf.data.Dataset:** The train and validation splits from Step 0 are passed to a function that converts the features and labels into a `tf.data.Dataset` for training and evaluation.
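
A minimal sketch of that conversion, using toy stand-in features and labels (the actual code would read the census CSV files):

```python
import tensorflow as tf

# Toy stand-ins for the census features and integer income_level labels.
features = {
    "age": [25, 48, 61, 33],
    "wage_per_hour": [0.0, 1200.0, 0.0, 800.0],
}
labels = [0, 1, 0, 1]

def make_dataset(features, labels, batch_size=2, shuffle=True):
    # from_tensor_slices pairs each feature dict entry with its label.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(labels))
    return dataset.batch(batch_size)

train_ds = make_dataset(features, labels)
# Each batch is a (dict of feature tensors, label tensor) pair.
batch_features, batch_labels = next(iter(train_ds))
```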

2. **Define the encoding logic for input features:** All features are encoded so that they share the same dimensionality.

   - **Categorical features** are encoded using the Keras *Embedding* layer, with the embedding output dimension equal to *encoding_size*

   - **Numerical features** are projected into an *encoding_size*-dimensional vector by applying a linear transformation using the Keras *Dense* layer
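
As a minimal sketch of both encodings (the feature values, vocabulary, and `encoding_size = 16` are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.keras import layers

encoding_size = 16  # assumed shared dimensionality for every encoded feature

# Categorical feature: string -> integer index -> embedding vector.
vocabulary = ["craft-repair", "sales", "other-service"]  # toy vocabulary
lookup = layers.StringLookup(vocabulary=vocabulary)
embedding = layers.Embedding(
    input_dim=lookup.vocabulary_size(), output_dim=encoding_size
)
encoded_cat = embedding(lookup(tf.constant(["sales"])))  # shape (1, 16)

# Numerical feature: scalar -> linear projection to the same dimensionality.
projection = layers.Dense(units=encoding_size)
encoded_num = projection(tf.constant([[37.0]]))          # shape (1, 16)
```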

3. **Implement the Gated Linear Unit (GLU):** The GLU consists of two *Dense* layers, the second of which has a sigmoid activation. GLUs help suppress inputs that are not useful for a given task.
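
A self-contained sketch of such a layer (class and attribute names are illustrative):

```python
import tensorflow as tf
from tensorflow.keras import layers

class GatedLinearUnit(layers.Layer):
    """Element-wise gating: linear(x) * sigmoid(gate(x))."""

    def __init__(self, units):
        super().__init__()
        self.linear = layers.Dense(units)
        self.sigmoid = layers.Dense(units, activation="sigmoid")

    def call(self, inputs):
        # The sigmoid branch outputs gates in (0, 1) that can
        # suppress components of the linear branch.
        return self.linear(inputs) * self.sigmoid(inputs)

glu = GatedLinearUnit(units=16)
out = glu(tf.ones((2, 8)))  # shape (2, 16)
```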
|
83 |
+
|
84 |
+
4. **Implement the Gated Residual Network:**

   - Applies a non-linear ELU transformation to its inputs
   - Applies a linear transformation followed by dropout
   - Applies a GLU, then adds the original inputs to the output of the GLU to form a skip (residual) connection
   - Applies layer normalization to produce the output
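
The steps above can be sketched as follows (names are illustrative, and the GLU gating from Step 3 is inlined so the snippet is self-contained):

```python
import tensorflow as tf
from tensorflow.keras import layers

class GatedResidualNetwork(layers.Layer):
    def __init__(self, units, dropout_rate=0.2):
        super().__init__()
        self.units = units
        self.elu_dense = layers.Dense(units, activation="elu")  # non-linear ELU transform
        self.linear_dense = layers.Dense(units)                 # linear transform
        self.dropout = layers.Dropout(dropout_rate)
        # GLU gating, inlined: linear branch * sigmoid gate.
        self.glu_linear = layers.Dense(units)
        self.glu_sigmoid = layers.Dense(units, activation="sigmoid")
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(units)  # matches input width for the residual add

    def call(self, inputs):
        x = self.elu_dense(inputs)
        x = self.dropout(self.linear_dense(x))
        x = self.glu_linear(x) * self.glu_sigmoid(x)
        if inputs.shape[-1] != self.units:     # project inputs if widths differ
            inputs = self.project(inputs)
        return self.layer_norm(inputs + x)     # skip connection + layer norm

grn = GatedResidualNetwork(units=16)
out = grn(tf.ones((2, 8)))  # shape (2, 16)
```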

5. **Implement the Variable Selection Network:**
   - Applies a Gated Residual Network (GRN), as defined in Step 4, to each feature individually
   - Applies a GRN to the concatenation of all features, followed by a softmax, to produce feature weights
   - Produces a weighted sum of the outputs of the individual GRNs
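
A sketch of this network, under the same assumptions (illustrative names; a compact GRN from Step 4 is repeated so the snippet is self-contained):

```python
import tensorflow as tf
from tensorflow.keras import layers

class GatedResidualNetwork(layers.Layer):
    """Compact GRN (see Step 4), repeated here for self-containment."""

    def __init__(self, units, dropout_rate=0.2):
        super().__init__()
        self.units = units
        self.elu_dense = layers.Dense(units, activation="elu")
        self.linear_dense = layers.Dense(units)
        self.dropout = layers.Dropout(dropout_rate)
        self.glu_linear = layers.Dense(units)
        self.glu_sigmoid = layers.Dense(units, activation="sigmoid")
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(units)

    def call(self, inputs):
        x = self.elu_dense(inputs)
        x = self.dropout(self.linear_dense(x))
        x = self.glu_linear(x) * self.glu_sigmoid(x)
        if inputs.shape[-1] != self.units:
            inputs = self.project(inputs)
        return self.layer_norm(inputs + x)

class VariableSelectionNetwork(layers.Layer):
    def __init__(self, num_features, units, dropout_rate=0.2):
        super().__init__()
        # One GRN per feature, plus one GRN over the concatenation.
        self.grns = [GatedResidualNetwork(units, dropout_rate) for _ in range(num_features)]
        self.grn_concat = GatedResidualNetwork(units, dropout_rate)
        self.softmax = layers.Dense(num_features, activation="softmax")

    def call(self, inputs):  # inputs: list of (batch, feature_dim) tensors
        # Feature weights: GRN over the concatenated features, then softmax.
        weights = self.softmax(self.grn_concat(tf.concat(inputs, axis=-1)))  # (batch, n)
        # Each feature through its own GRN, stacked to (batch, n, units).
        x = tf.stack([grn(f) for grn, f in zip(self.grns, inputs)], axis=1)
        # Weighted sum over the feature axis -> (batch, units).
        return tf.squeeze(tf.matmul(tf.expand_dims(weights, axis=1), x), axis=1)

vsn = VariableSelectionNetwork(num_features=3, units=16)
out = vsn([tf.ones((2, 8)), tf.ones((2, 4)), tf.ones((2, 16))])  # shape (2, 16)
```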
6. **Create Model:**

   - The model has input layers corresponding to both the numerical and categorical features of the dataset
   - The features received by the input layers are encoded using the logic defined in Step 2
   - The encoded features pass through the Variable Selection Network (VSN)
   - The output of the VSN is passed through a final *Dense* layer with a sigmoid activation to produce the model's output

7. **Compile, Train, and Evaluate the Model:** The model is compiled with the Adam optimizer and, since the task is binary classification, binary cross-entropy loss.
   The model is trained for 20 epochs with a batch size of 265 and a callback for early stopping.
   Model performance is evaluated by the accuracy and loss observed on the validation set.
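
The training setup can be sketched as follows, with a tiny stand-in model and synthetic data (the real model and datasets come from the earlier steps; the learning rate and early-stopping `monitor`/`patience` settings are assumptions, while the epochs and batch size match the values above):

```python
import numpy as np
from tensorflow import keras

# Stand-in model and synthetic data, only to illustrate the training setup.
x = np.random.rand(200, 8).astype("float32")
y = np.random.randint(0, 2, size=(200, 1)).astype("float32")
model = keras.Sequential([
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),  # binary output
])

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),  # learning rate assumed
    loss=keras.losses.BinaryCrossentropy(),
    metrics=["accuracy"],
)
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True  # assumed settings
)
history = model.fit(
    x, y,
    validation_split=0.2,
    epochs=20, batch_size=265,  # as stated in the card
    callbacks=[early_stopping],
    verbose=0,
)
```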
### Training hyperparameters
The following hyperparameters were used during training: