Robin Tully committed
Commit 9586c0c · 1 Parent(s): 20d9458

training notebook

Files changed (6):
  1. .gitignore +167 -0
  2. .python-version +1 -0
  3. README.md +65 -1
  4. model_training.ipynb +1269 -0
  5. pyproject.toml +27 -0
  6. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
+ .env
+ .ruff_cache
+ catboost_info/
+ modernBERT-content-regression/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
README.md CHANGED
@@ -1 +1,65 @@
- # modernbert-content-regression
+
+ ### What is this?
+
+ This is an exploration of using ModernBERT for the text regression task of predicting engagement metrics for text content. In this case, we are predicting the clickthrough rate (CTR) of email text content.
+
+ We will be exploring hyperparameter tuning of ModernBERT and how to use it for regression, as well as comparing the results to a benchmark model.
+
+ This type of task is difficult; we can recall the quote
+ > “Half my advertising is wasted; the trouble is, I don't know which half”
+ > - John Wanamaker
+
+ In this experiment we are also excluding other relevant factors, such as the time of day the email is sent, the day of the week, the recipient, etc.
+
+ This work is indebted to the work of many community members and blog posts.
+ - [ModernBERT Announcement](https://huggingface.co/blog/modernbert)
+ - [Fine-tune classifier with ModernBERT in 2025](https://www.philschmid.de/fine-tune-modern-bert-in-2025)
+ - [How to set up Trainer for a regression](https://discuss.huggingface.co/t/how-to-set-up-trainer-for-a-regression/12994)
+
+ ### Our dataset
+ We will be using a dataset of 548 emails where we have the text of the email (`text`) and the CTR we are trying to predict (`labels`).
+
+ We look forward to improvements in ModernBERT for fine-tuning models specifically on each potential user's email dataset. The variability of email data, as well as the small size of the dataset, poses an interesting regression challenge.
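For reference, the dataset prep is only a few lines; this sketch mirrors the notebook's cells (the dataset id, the column rename, and the 90/10 split with seed 1 are all taken from the notebook):

```python
from datasets import load_dataset

# Load the 548-email dataset and rename the target column, as the notebook does.
raw_dataset = load_dataset("Forecast-ing/email-clickthrough")
raw_dataset = raw_dataset.rename_column("label", "labels")
raw_dataset = raw_dataset["train"].train_test_split(test_size=0.1, seed=1)
print(len(raw_dataset["train"]), len(raw_dataset["test"]))  # 493, 55
```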
+
+ ### Benchmarking
+ We will start by using the CatBoost library as a simple benchmark for text regression. For both the benchmark and the ModernBERT run, we are using RMSE as the metric.
+ We receive the following results:
+ | Metric | Value |
+ |--------|------------------|
+ | MSE | 2.552100633998035 |
+ | RMSE | 1.5975295408843102 |
+ | MAE | 1.1439370629666958 |
+ | R² | 0.30127932054387174 |
+ | SMAPE | 37.63064694052479 |
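The benchmark and its scoring are a short sketch matching the notebook's cells: an RMSE-loss CatBoost model on the raw text column, evaluated with scikit-learn plus a small SMAPE helper (`catboost_train` / `catboost_test` are the pandas frames of the train/test split above):

```python
import numpy as np
from catboost import CatBoostRegressor, Pool
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Text columns go in as raw strings; CatBoost handles the text featurization.
train_pool = Pool(data=catboost_train[["text"]], label=catboost_train["labels"], text_features=["text"])
test_pool = Pool(data=catboost_test[["text"]], label=catboost_test["labels"], text_features=["text"])

model = CatBoostRegressor(loss_function="RMSE", verbose=100)
model.fit(train_pool, eval_set=test_pool)


def smape(y_true, y_pred):
    # Symmetric mean absolute percentage error, in percent.
    return 100 * np.mean(2 * np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred)))


def calculate_metrics(y_val, y_pred):
    mse = mean_squared_error(y_val, y_pred)
    return {
        "mse": mse,
        "rmse": np.sqrt(mse),
        "mae": mean_absolute_error(y_val, y_pred),
        "r2": r2_score(y_val, y_pred),
        "smape": smape(y_val, y_pred),
    }


print(calculate_metrics(catboost_test["labels"], model.predict(test_pool)))
```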
+
+ ## Fitting the ModernBERT Model
+
+ ### Install dependencies and activate venv
+ ```bash
+ uv sync
+ source .venv/bin/activate
+ ```
+ The following value needs to be defined in the `.env` file (the notebook reads it via `os.getenv("HUGGINGFACE_API_KEY")`):
+ - `HUGGINGFACE_API_KEY`
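A minimal `.env` sketch (the value shown is a placeholder, not a real token):

```bash
HUGGINGFACE_API_KEY=hf_xxxxxxxxxxxxxxxx
```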
+
+ ### Run notebook for model fitting
+
+ ```bash
+ uv run --with jupyter jupyter lab
+ ```
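The heart of the notebook is a learning-rate search with Optuna through `Trainer.hyperparameter_search`; the sketch below is condensed from the notebook's cells (`tokenized_dataset` and `compute_metrics_for_regression` are built in earlier notebook cells, with tokenization padded/truncated to 2048 tokens):

```python
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.model_max_length = 2048


def model_init(trial):
    # num_labels=1 with problem_type="regression" gives a single-output head trained with MSE.
    return AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=1, ignore_mismatched_sizes=True, problem_type="regression"
    )


trainer = Trainer(
    model=None,
    model_init=model_init,
    args=TrainingArguments(
        output_dir="./modernBERT-content-regression",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=5,
        eval_strategy="epoch",
        save_strategy="epoch",
    ),
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics_for_regression,
)

best_trial = trainer.hyperparameter_search(
    direction="minimize",
    backend="optuna",
    hp_space=lambda trial: {
        "learning_rate": trial.suggest_float("learning_rate", 5e-7, 5e-5, log=True)
    },
    n_trials=5,
    compute_objective=lambda metrics: metrics["eval_rmse"],
)
print(best_trial)  # in the notebook's run: learning_rate ≈ 2.48e-5, eval RMSE ≈ 1.62
```

As the notebook notes, the best trial's hyperparameters are then used to re-fit a fresh `Trainer` before pushing to the hub.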
+
+ ### ModernBERT Model Performance
+ After running hyperparameter tuning for ModernBERT, we get the following results:
+
+ | Metric | Value |
+ |--------|------------------|
+ | MSE | 2.4624056816101074 |
+ | RMSE | 1.5692054300218654 |
+ | MAE | 1.182181715965271 |
+ | R² | 0.325836181640625 |
+ | SMAPE | 56.61447048187256 |
+
+ We see improvements in all metrics except SMAPE. We believe that ModernBERT would scale even better with a larger dataset, as roughly 500 examples is very little for fine-tuning, and we are thus happy with the performance of this evaluation.
+
+ ## Conclusion
+ We see that ModernBERT is a powerful model for text regression. We believe that with a larger dataset, we would see even better results. We are excited to see the future of ModernBERT and how it will be used for text regression.
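A hypothetical usage sketch for the resulting model (the repo id is the one the notebook pushes to; since it is created as a private repo, loading may require your token):

```python
# Hypothetical inference sketch: load the fine-tuned regressor and predict a CTR.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "Forecast-ing/modernBERT-content-regression"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id)
model.eval()

inputs = tokenizer("Subject: Big spring sale ...", truncation=True, return_tensors="pt")
with torch.no_grad():
    ctr = model(**inputs).logits.squeeze().item()  # single regression output
print(f"Predicted CTR: {ctr:.2f}")
```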
model_training.ipynb ADDED
@@ -0,0 +1,1269 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "51a9bb64-969a-4fc0-aa76-4bd42b08c21a",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "text/plain": [
12
+ "True"
13
+ ]
14
+ },
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "output_type": "execute_result"
18
+ }
19
+ ],
20
+ "source": [
21
+ "import os\n",
22
+ "\n",
23
+ "import numpy as np\n",
24
+ "import pandas as pd\n",
25
+ "from catboost import CatBoostRegressor, Pool\n",
26
+ "from datasets import load_dataset\n",
27
+ "from dotenv import load_dotenv\n",
28
+ "from huggingface_hub import HfFolder, login\n",
29
+ "from sklearn.metrics import (\n",
30
+ " mean_absolute_error,\n",
31
+ " mean_squared_error,\n",
32
+ " r2_score,\n",
33
+ " root_mean_squared_error,\n",
34
+ ")\n",
35
+ "from transformers import (\n",
36
+ " AutoModelForSequenceClassification,\n",
37
+ " AutoTokenizer,\n",
38
+ " Trainer,\n",
39
+ " TrainingArguments,\n",
40
+ ")\n",
41
+ "\n",
42
+ "load_dotenv()"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 2,
48
+ "id": "9b53c782-5dbb-4dd6-b541-8e4fab3f3ddf",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "login(token=os.getenv(\"HUGGINGFACE_API_KEY\"))"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "id": "87f58ec1-231d-4c17-8af2-c366af55e375",
58
+ "metadata": {},
59
+ "source": [
60
+ "### Dataset prep"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 3,
66
+ "id": "0ea777d8-988b-421a-8d76-b0f9256ab61b",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "raw_dataset = load_dataset(\"Forecast-ing/email-clickthrough\")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 4,
76
+ "id": "0c8441e7-1606-4c61-8f6c-34ebe1e107c0",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "raw_dataset = raw_dataset.rename_column(\"label\", \"labels\")"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 5,
86
+ "id": "201b806d-6c94-4053-98b6-4d22dcdda08a",
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "data": {
91
+ "text/plain": [
92
+ "3292"
93
+ ]
94
+ },
95
+ "execution_count": 5,
96
+ "metadata": {},
97
+ "output_type": "execute_result"
98
+ }
99
+ ],
100
+ "source": [
101
+ "raw_dataset[\"train\"].to_pandas()[\"text\"].str.len().max()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 6,
107
+ "id": "cbd4d6ec-b293-49f1-925f-239549dab61e",
108
+ "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "data": {
112
+ "text/plain": [
113
+ "0.2427007299270073"
114
+ ]
115
+ },
116
+ "execution_count": 6,
117
+ "metadata": {},
118
+ "output_type": "execute_result"
119
+ }
120
+ ],
121
+ "source": [
122
+ "(raw_dataset[\"train\"].to_pandas()[\"text\"].str.len() > 2048).mean()"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 7,
128
+ "id": "1101c038-3f83-4055-b938-3861ac43cf8f",
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "data": {
133
+ "text/plain": [
134
+ "count 548.000000\n",
135
+ "mean 2.879635\n",
136
+ "std 2.423870\n",
137
+ "min 0.450000\n",
138
+ "25% 1.510000\n",
139
+ "50% 2.025000\n",
140
+ "75% 3.267500\n",
141
+ "max 25.370000\n",
142
+ "Name: labels, dtype: float64"
143
+ ]
144
+ },
145
+ "execution_count": 7,
146
+ "metadata": {},
147
+ "output_type": "execute_result"
148
+ }
149
+ ],
150
+ "source": [
151
+ "raw_dataset[\"train\"].to_pandas()[\"labels\"].describe()"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 8,
157
+ "id": "7546577f-4d7f-41b5-a68f-095fc0e8eec4",
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "raw_dataset = raw_dataset[\"train\"].train_test_split(test_size=0.1, seed=1)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 9,
167
+ "id": "a0c28ab6-20f8-47dc-8ee9-179ea15830e0",
168
+ "metadata": {},
169
+ "outputs": [
170
+ {
171
+ "name": "stdout",
172
+ "output_type": "stream",
173
+ "text": [
174
+ "Train dataset size: 493\n",
175
+ "Test dataset size: 55\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "print(f\"Train dataset size: {len(raw_dataset['train'])}\")\n",
181
+ "print(f\"Test dataset size: {len(raw_dataset['test'])}\")"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "markdown",
186
+ "id": "a2c3e7c3-e31d-4d8f-8d7d-e62050a9ae9d",
187
+ "metadata": {},
188
+ "source": [
189
+ "### Catboost Benchmark"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 10,
195
+ "id": "01aaea26-e1df-4493-b9cd-732c3b7a76a9",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "catboost_train = raw_dataset[\"train\"].to_pandas()\n",
200
+ "catboost_test = raw_dataset[\"test\"].to_pandas()"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 11,
206
+ "id": "0243e07d-69ba-41e5-b54d-d0f4988bbf9f",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "text_columns = [\"text\"]\n",
211
+ "label = \"labels\""
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 12,
217
+ "id": "ba17040c-5882-47dc-a8af-a7557356840f",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "train_pool = Pool(\n",
222
+ " data=catboost_train[text_columns],\n",
223
+ " label=catboost_train[label],\n",
224
+ " text_features=text_columns,\n",
225
+ ")\n",
226
+ "test_pool = Pool(\n",
227
+ " data=catboost_test[text_columns],\n",
228
+ " label=catboost_test[label],\n",
229
+ " text_features=text_columns,\n",
230
+ ")"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 13,
236
+ "id": "d8b3768e-6f30-41ce-a209-bd915a997d8a",
237
+ "metadata": {},
238
+ "outputs": [
239
+ {
240
+ "name": "stdout",
241
+ "output_type": "stream",
242
+ "text": [
243
+ "Learning rate set to 0.045569\n",
244
+ "0:\tlearn: 2.4332854\ttest: 1.8670741\tbest: 1.8670741 (0)\ttotal: 60.5ms\tremaining: 1m\n",
245
+ "100:\tlearn: 1.4972558\ttest: 1.6247590\tbest: 1.6048404 (59)\ttotal: 2.5s\tremaining: 22.2s\n",
246
+ "200:\tlearn: 1.1104040\ttest: 1.6015944\tbest: 1.5975296 (197)\ttotal: 4.91s\tremaining: 19.5s\n",
247
+ "300:\tlearn: 0.8568033\ttest: 1.6102309\tbest: 1.5975296 (197)\ttotal: 7.33s\tremaining: 17s\n",
248
+ "400:\tlearn: 0.7096792\ttest: 1.6090190\tbest: 1.5975296 (197)\ttotal: 9.72s\tremaining: 14.5s\n",
249
+ "500:\tlearn: 0.6056532\ttest: 1.6083240\tbest: 1.5975296 (197)\ttotal: 12.1s\tremaining: 12s\n",
250
+ "600:\tlearn: 0.5298016\ttest: 1.6175366\tbest: 1.5975296 (197)\ttotal: 14.5s\tremaining: 9.64s\n",
251
+ "700:\tlearn: 0.4701467\ttest: 1.6262668\tbest: 1.5975296 (197)\ttotal: 16.9s\tremaining: 7.23s\n",
252
+ "800:\tlearn: 0.4233732\ttest: 1.6199203\tbest: 1.5975296 (197)\ttotal: 19.4s\tremaining: 4.81s\n",
253
+ "900:\tlearn: 0.3837074\ttest: 1.6104091\tbest: 1.5975296 (197)\ttotal: 21.8s\tremaining: 2.39s\n",
254
+ "999:\tlearn: 0.3501113\ttest: 1.6131207\tbest: 1.5975296 (197)\ttotal: 24.2s\tremaining: 0us\n",
255
+ "\n",
256
+ "bestTest = 1.597529566\n",
257
+ "bestIteration = 197\n",
258
+ "\n",
259
+ "Shrink model to first 198 iterations.\n"
260
+ ]
261
+ },
262
+ {
263
+ "data": {
264
+ "text/plain": [
265
+ "<catboost.core.CatBoostRegressor at 0x7fb1061c5bb0>"
266
+ ]
267
+ },
268
+ "execution_count": 13,
269
+ "metadata": {},
270
+ "output_type": "execute_result"
271
+ }
272
+ ],
273
+ "source": [
274
+ "model = CatBoostRegressor(loss_function=\"RMSE\", verbose=100)\n",
275
+ "\n",
276
+ "model.fit(train_pool, eval_set=test_pool)"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 14,
282
+ "id": "837b22a8-241d-49b3-a1ae-915893121319",
283
+ "metadata": {},
284
+ "outputs": [],
285
+ "source": [
286
+ "y_pred = model.predict(test_pool)\n",
287
+ "y_val = catboost_test[label]"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 15,
293
+ "id": "478521bf-85be-49a1-8461-020587f146d2",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "def smape(y_true, y_pred):\n",
298
+ " return 100 * np.mean(\n",
299
+ " 2 * np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))\n",
300
+ " )\n",
301
+ "\n",
302
+ "\n",
303
+ "def calculate_metrics(y_val, y_pred):\n",
304
+ " mse = mean_squared_error(y_val, y_pred)\n",
305
+ " rmse = np.sqrt(mse)\n",
306
+ " mae = mean_absolute_error(y_val, y_pred)\n",
307
+ " r2 = r2_score(y_val, y_pred)\n",
308
+ " smape_value = smape(y_val, y_pred)\n",
309
+ " return {\n",
310
+ " \"mse\": mse,\n",
311
+ " \"rmse\": rmse,\n",
312
+ " \"mae\": mae,\n",
313
+ " \"r2\": r2,\n",
314
+ " \"smape\": smape_value,\n",
315
+ " }"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 16,
321
+ "id": "91ea95e2-1818-45a6-8725-3a1353cb5b97",
322
+ "metadata": {},
323
+ "outputs": [],
324
+ "source": [
325
+ "catboost_metrics = calculate_metrics(y_val, y_pred)"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": 17,
331
+ "id": "e28e359e-f69c-4ee8-9bbd-e7afafafcd26",
332
+ "metadata": {},
333
+ "outputs": [
334
+ {
335
+ "data": {
336
+ "text/plain": [
337
+ "{'mse': 2.552100633998035,\n",
338
+ " 'rmse': 1.5975295408843102,\n",
339
+ " 'mae': 1.1439370629666958,\n",
340
+ " 'r2': 0.30127932054387174,\n",
341
+ " 'smape': 37.63064694052479}"
342
+ ]
343
+ },
344
+ "execution_count": 17,
345
+ "metadata": {},
346
+ "output_type": "execute_result"
347
+ }
348
+ ],
349
+ "source": [
350
+ "catboost_metrics"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "id": "7afac97a-1e69-47e4-9ecd-ece7ebe7b48f",
356
+ "metadata": {},
357
+ "source": [
358
+ "### Fine Tuning Modern Bert"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": 18,
364
+ "id": "031df047-2c18-4ec9-a498-596a7cf965b7",
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": [
368
+ "model_id = \"answerdotai/ModernBERT-base\"\n",
369
+ "\n",
370
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
371
+ "tokenizer.model_max_length = 2048\n",
372
+ "\n",
373
+ "def tokenize(batch):\n",
374
+ " return tokenizer(\n",
375
+ " batch[\"text\"], padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
376
+ " )"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 19,
382
+ "id": "bbb711e8-8da6-401b-a2bd-c1637731f6c9",
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": [
386
+ "tokenized_dataset = raw_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": 20,
392
+ "id": "ca4c2942-bf9e-4579-82b9-654fbee85b54",
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "def model_init(trial):\n",
397
+ " model = AutoModelForSequenceClassification.from_pretrained(\n",
398
+ " model_id, num_labels=1, ignore_mismatched_sizes=True, problem_type=\"regression\"\n",
399
+ " )\n",
400
+ " return model"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": 21,
406
+ "id": "a2005499-e139-4151-9ebe-2759710149b1",
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": [
410
+ "def gen_training_args(additional_args={}):\n",
411
+ " default_args = {\n",
412
+ " \"output_dir\": \"./modernBERT-content-regression\",\n",
413
+ " \"per_device_eval_batch_size\": 4,\n",
414
+ " \"per_device_train_batch_size\": 4,\n",
415
+ " \"num_train_epochs\": 5,\n",
416
+ " \"bf16\": True, # bfloat16 training\n",
417
+ " \"optim\": \"adamw_torch_fused\", # improved optimizer\n",
418
+ " \"logging_strategy\": \"steps\",\n",
419
+ " \"logging_steps\": 1,\n",
420
+ " \"evaluation_strategy\": \"epoch\",\n",
421
+ " \"save_strategy\": \"epoch\",\n",
422
+ " \"save_total_limit\": 1,\n",
423
+ " \"metric_for_best_model\": \"rmse\",\n",
424
+ " \"greater_is_better\": False,\n",
425
+ " \"report_to\": \"tensorboard\",\n",
426
+ " \"push_to_hub\": True,\n",
427
+ " \"hub_private_repo\": True,\n",
428
+ " \"hub_strategy\": \"every_save\",\n",
429
+ " \"hub_token\": HfFolder.get_token(),\n",
430
+ " }\n",
431
+ " training_args = TrainingArguments(**default_args, **additional_args)\n",
432
+ " return training_args"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": 22,
438
+ "id": "e7e0ee17-9a1c-4789-8a4b-d2469a88837b",
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "def compute_metrics_for_regression(eval_pred):\n",
443
+ " predictions, labels = eval_pred\n",
444
+ " predictions = predictions.reshape(-1, 1)\n",
445
+ " results = calculate_metrics(labels, predictions)\n",
446
+ " return results\n"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": 23,
452
+ "id": "10981164-fffc-4128-89ae-41c9238074cd",
453
+ "metadata": {},
454
+ "outputs": [
455
+ {
456
+ "name": "stderr",
457
+ "output_type": "stream",
458
+ "text": [
459
+ "/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
460
+ " warnings.warn(\n",
461
+ "/tmp/ipykernel_22314/2727960756.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
462
+ " hp_trainer = Trainer(\n",
463
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
464
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
465
+ ]
466
+ }
467
+ ],
468
+ "source": [
469
+ "hp_trainer = Trainer(\n",
470
+ " model=None,\n",
471
+ " args=gen_training_args(),\n",
472
+ " train_dataset=tokenized_dataset[\"train\"],\n",
473
+ " eval_dataset=tokenized_dataset[\"test\"],\n",
474
+ " tokenizer=tokenizer,\n",
475
+ " compute_metrics=compute_metrics_for_regression,\n",
476
+ " model_init=model_init,\n",
477
+ ")"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "execution_count": 24,
483
+ "id": "261a4988-e5f9-469f-b8ae-ef51ae6a95df",
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "def optuna_hp_space(trial):\n",
488
+ " return {\n",
489
+ " \"learning_rate\": trial.suggest_float(\"learning_rate\", 5e-7, 5e-5, log=True),\n",
490
+ " }"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": 25,
496
+ "id": "7e7fb4f6-4eb4-4a4c-931d-a97c1614a5b9",
497
+ "metadata": {},
498
+ "outputs": [
499
+ {
500
+ "name": "stderr",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "[I 2025-01-09 12:16:25,726] A new study created in memory with name: no-name-2f3f9073-d130-4bb1-9447-7262f2b7bd75\n",
504
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
505
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
506
+ ]
507
+ },
508
+ {
509
+ "data": {
510
+ "text/html": [
511
+ "\n",
512
+ " <div>\n",
513
+ " \n",
514
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
515
+ " [620/620 03:27, Epoch 5/5]\n",
516
+ " </div>\n",
517
+ " <table border=\"1\" class=\"dataframe\">\n",
518
+ " <thead>\n",
519
+ " <tr style=\"text-align: left;\">\n",
520
+ " <th>Epoch</th>\n",
521
+ " <th>Training Loss</th>\n",
522
+ " <th>Validation Loss</th>\n",
523
+ " <th>Mse</th>\n",
524
+ " <th>Rmse</th>\n",
525
+ " <th>Mae</th>\n",
526
+ " <th>R2</th>\n",
527
+ " <th>Smape</th>\n",
528
+ " </tr>\n",
529
+ " </thead>\n",
530
+ " <tbody>\n",
531
+ " <tr>\n",
532
+ " <td>1</td>\n",
533
+ " <td>0.238000</td>\n",
534
+ " <td>4.573008</td>\n",
535
+ " <td>4.573008</td>\n",
536
+ " <td>2.138459</td>\n",
537
+ " <td>1.324540</td>\n",
538
+ " <td>-0.252010</td>\n",
539
+ " <td>54.242009</td>\n",
540
+ " </tr>\n",
541
+ " <tr>\n",
542
+ " <td>2</td>\n",
543
+ " <td>3.768500</td>\n",
544
+ " <td>4.093452</td>\n",
545
+ " <td>4.093452</td>\n",
546
+ " <td>2.023228</td>\n",
547
+ " <td>1.458057</td>\n",
548
+ " <td>-0.120716</td>\n",
549
+ " <td>53.770840</td>\n",
550
+ " </tr>\n",
551
+ " <tr>\n",
552
+ " <td>3</td>\n",
553
+ " <td>27.661000</td>\n",
554
+ " <td>3.361875</td>\n",
555
+ " <td>3.361874</td>\n",
556
+ " <td>1.833541</td>\n",
557
+ " <td>1.126670</td>\n",
558
+ " <td>0.079577</td>\n",
559
+ " <td>52.641284</td>\n",
560
+ " </tr>\n",
561
+ " <tr>\n",
562
+ " <td>4</td>\n",
563
+ " <td>0.092300</td>\n",
564
+ " <td>2.759459</td>\n",
565
+ " <td>2.759459</td>\n",
566
+ " <td>1.661162</td>\n",
567
+ " <td>1.040074</td>\n",
568
+ " <td>0.244508</td>\n",
569
+ " <td>53.009331</td>\n",
570
+ " </tr>\n",
571
+ " <tr>\n",
572
+ " <td>5</td>\n",
573
+ " <td>0.020300</td>\n",
574
+ " <td>2.733250</td>\n",
575
+ " <td>2.733250</td>\n",
576
+ " <td>1.653254</td>\n",
577
+ " <td>1.078653</td>\n",
578
+ " <td>0.251684</td>\n",
579
+ " <td>54.187167</td>\n",
580
+ " </tr>\n",
581
+ " </tbody>\n",
582
+ "</table><p>"
583
+ ],
584
+ "text/plain": [
585
+ "<IPython.core.display.HTML object>"
586
+ ]
587
+ },
588
+ "metadata": {},
589
+ "output_type": "display_data"
590
+ },
591
+ {
592
+ "name": "stderr",
593
+ "output_type": "stream",
594
+ "text": [
595
+ "[I 2025-01-09 12:19:57,000] Trial 0 finished with value: 1.6532543369745685 and parameters: {'learning_rate': 1.9437267223645173e-05}. Best is trial 0 with value: 1.6532543369745685.\n",
596
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
597
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
598
+ ]
599
+ },
600
+ {
601
+ "data": {
602
+ "text/html": [
603
+ "\n",
604
+ " <div>\n",
605
+ " \n",
606
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
607
+ " [620/620 03:30, Epoch 5/5]\n",
608
+ " </div>\n",
609
+ " <table border=\"1\" class=\"dataframe\">\n",
610
+ " <thead>\n",
611
+ " <tr style=\"text-align: left;\">\n",
612
+ " <th>Epoch</th>\n",
613
+ " <th>Training Loss</th>\n",
614
+ " <th>Validation Loss</th>\n",
615
+ " <th>Mse</th>\n",
616
+ " <th>Rmse</th>\n",
617
+ " <th>Mae</th>\n",
618
+ " <th>R2</th>\n",
619
+ " <th>Smape</th>\n",
620
+ " </tr>\n",
621
+ " </thead>\n",
622
+ " <tbody>\n",
623
+ " <tr>\n",
624
+ " <td>1</td>\n",
625
+ " <td>0.033500</td>\n",
626
+ " <td>3.730757</td>\n",
627
+ " <td>3.730757</td>\n",
628
+ " <td>1.931517</td>\n",
629
+ " <td>1.167591</td>\n",
630
+ " <td>-0.021416</td>\n",
631
+ " <td>46.438679</td>\n",
632
+ " </tr>\n",
633
+ " <tr>\n",
634
+ " <td>2</td>\n",
635
+ " <td>3.021100</td>\n",
636
+ " <td>3.532418</td>\n",
637
+ " <td>3.532420</td>\n",
638
+ " <td>1.879473</td>\n",
639
+ " <td>1.171051</td>\n",
640
+ " <td>0.032885</td>\n",
641
+ " <td>48.273236</td>\n",
642
+ " </tr>\n",
643
+ " <tr>\n",
644
+ " <td>3</td>\n",
645
+ " <td>32.454400</td>\n",
646
+ " <td>3.670944</td>\n",
647
+ " <td>3.670944</td>\n",
648
+ " <td>1.915971</td>\n",
649
+ " <td>1.159171</td>\n",
650
+ " <td>-0.005041</td>\n",
651
+ " <td>48.529482</td>\n",
652
+ " </tr>\n",
653
+ " <tr>\n",
654
+ " <td>4</td>\n",
655
+ " <td>0.074300</td>\n",
656
+ " <td>3.690546</td>\n",
657
+ " <td>3.690546</td>\n",
658
+ " <td>1.921079</td>\n",
659
+ " <td>1.179955</td>\n",
660
+ " <td>-0.010407</td>\n",
661
+ " <td>49.107727</td>\n",
662
+ " </tr>\n",
663
+ " <tr>\n",
664
+ " <td>5</td>\n",
665
+ " <td>0.098800</td>\n",
666
+ " <td>3.677439</td>\n",
667
+ " <td>3.677439</td>\n",
668
+ " <td>1.917665</td>\n",
669
+ " <td>1.188619</td>\n",
670
+ " <td>-0.006819</td>\n",
671
+ " <td>49.251461</td>\n",
672
+ " </tr>\n",
673
+ " </tbody>\n",
674
+ "</table><p>"
675
+ ],
676
+ "text/plain": [
677
+ "<IPython.core.display.HTML object>"
678
+ ]
679
+ },
680
+ "metadata": {},
681
+ "output_type": "display_data"
682
+ },
683
+ {
684
+ "name": "stderr",
685
+ "output_type": "stream",
686
+ "text": [
687
+ "[I 2025-01-09 12:23:31,566] Trial 1 finished with value: 1.91766510403085 and parameters: {'learning_rate': 1.5810058165067856e-06}. Best is trial 0 with value: 1.6532543369745685.\n",
688
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
689
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
690
+ ]
691
+ },
692
+ {
693
+ "data": {
694
+ "text/html": [
695
+ "\n",
696
+ " <div>\n",
697
+ " \n",
698
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
699
+ " [620/620 03:28, Epoch 5/5]\n",
700
+ " </div>\n",
701
+ " <table border=\"1\" class=\"dataframe\">\n",
702
+ " <thead>\n",
703
+ " <tr style=\"text-align: left;\">\n",
704
+ " <th>Epoch</th>\n",
705
+ " <th>Training Loss</th>\n",
706
+ " <th>Validation Loss</th>\n",
707
+ " <th>Mse</th>\n",
708
+ " <th>Rmse</th>\n",
709
+ " <th>Mae</th>\n",
710
+ " <th>R2</th>\n",
711
+ " <th>Smape</th>\n",
712
+ " </tr>\n",
713
+ " </thead>\n",
714
+ " <tbody>\n",
715
+ " <tr>\n",
716
+ " <td>1</td>\n",
717
+ " <td>0.311500</td>\n",
718
+ " <td>4.090590</td>\n",
719
+ " <td>4.090590</td>\n",
720
+ " <td>2.022521</td>\n",
721
+ " <td>1.229977</td>\n",
722
+ " <td>-0.119932</td>\n",
723
+ " <td>50.514507</td>\n",
724
+ " </tr>\n",
725
+ " <tr>\n",
726
+ " <td>2</td>\n",
727
+ " <td>2.652800</td>\n",
728
+ " <td>4.852318</td>\n",
729
+ " <td>4.852319</td>\n",
730
+ " <td>2.202798</td>\n",
731
+ " <td>1.465739</td>\n",
732
+ " <td>-0.328480</td>\n",
733
+ " <td>54.715651</td>\n",
734
+ " </tr>\n",
735
+ " <tr>\n",
736
+ " <td>3</td>\n",
737
+ " <td>24.626400</td>\n",
738
+ " <td>3.331610</td>\n",
739
+ " <td>3.331610</td>\n",
740
+ " <td>1.825270</td>\n",
741
+ " <td>1.143937</td>\n",
742
+ " <td>0.087863</td>\n",
743
+ " <td>51.898420</td>\n",
744
+ " </tr>\n",
745
+ " <tr>\n",
746
+ " <td>4</td>\n",
747
+ " <td>0.289600</td>\n",
748
+ " <td>2.353773</td>\n",
749
+ " <td>2.353773</td>\n",
750
+ " <td>1.534201</td>\n",
751
+ " <td>1.079125</td>\n",
752
+ " <td>0.355578</td>\n",
753
+ " <td>55.779856</td>\n",
754
+ " </tr>\n",
755
+ " <tr>\n",
756
+ " <td>5</td>\n",
757
+ " <td>0.001400</td>\n",
758
+ " <td>2.629261</td>\n",
759
+ " <td>2.629261</td>\n",
760
+ " <td>1.621500</td>\n",
761
+ " <td>1.166006</td>\n",
762
+ " <td>0.280154</td>\n",
763
+ " <td>57.977718</td>\n",
764
+ " </tr>\n",
765
+ " </tbody>\n",
766
+ "</table><p>"
767
+ ],
768
+ "text/plain": [
769
+ "<IPython.core.display.HTML object>"
770
+ ]
771
+ },
772
+ "metadata": {},
773
+ "output_type": "display_data"
774
+ },
775
+ {
776
+ "name": "stderr",
777
+ "output_type": "stream",
778
+ "text": [
779
+ "[I 2025-01-09 12:27:05,020] Trial 2 finished with value: 1.6214995462309338 and parameters: {'learning_rate': 2.479942619764035e-05}. Best is trial 2 with value: 1.6214995462309338.\n",
780
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
781
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
782
+ ]
783
+ },
784
+ {
785
+ "data": {
786
+ "text/html": [
787
+ "\n",
788
+ " <div>\n",
789
+ " \n",
790
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
791
+ " [620/620 03:25, Epoch 5/5]\n",
792
+ " </div>\n",
793
+ " <table border=\"1\" class=\"dataframe\">\n",
794
+ " <thead>\n",
795
+ " <tr style=\"text-align: left;\">\n",
796
+ " <th>Epoch</th>\n",
797
+ " <th>Training Loss</th>\n",
798
+ " <th>Validation Loss</th>\n",
799
+ " <th>Mse</th>\n",
800
+ " <th>Rmse</th>\n",
801
+ " <th>Mae</th>\n",
802
+ " <th>R2</th>\n",
803
+ " <th>Smape</th>\n",
804
+ " </tr>\n",
805
+ " </thead>\n",
806
+ " <tbody>\n",
807
+ " <tr>\n",
808
+ " <td>1</td>\n",
809
+ " <td>0.008000</td>\n",
810
+ " <td>3.590378</td>\n",
811
+ " <td>3.590379</td>\n",
812
+ " <td>1.894829</td>\n",
813
+ " <td>1.149898</td>\n",
814
+ " <td>0.017017</td>\n",
815
+ " <td>46.445611</td>\n",
816
+ " </tr>\n",
817
+ " <tr>\n",
818
+ " <td>2</td>\n",
819
+ " <td>2.704000</td>\n",
820
+ " <td>3.476464</td>\n",
821
+ " <td>3.476464</td>\n",
822
+ " <td>1.864528</td>\n",
823
+ " <td>1.125000</td>\n",
824
+ " <td>0.048205</td>\n",
825
+ " <td>47.319812</td>\n",
826
+ " </tr>\n",
827
+ " <tr>\n",
828
+ " <td>3</td>\n",
829
+ " <td>32.099300</td>\n",
830
+ " <td>3.543669</td>\n",
831
+ " <td>3.543668</td>\n",
832
+ " <td>1.882463</td>\n",
833
+ " <td>1.123369</td>\n",
834
+ " <td>0.029805</td>\n",
835
+ " <td>47.717217</td>\n",
836
+ " </tr>\n",
837
+ " <tr>\n",
838
+ " <td>4</td>\n",
839
+ " <td>0.058200</td>\n",
840
+ " <td>3.590872</td>\n",
841
+ " <td>3.590872</td>\n",
842
+ " <td>1.894960</td>\n",
843
+ " <td>1.142273</td>\n",
844
+ " <td>0.016882</td>\n",
845
+ " <td>48.410091</td>\n",
846
+ " </tr>\n",
847
+ " <tr>\n",
848
+ " <td>5</td>\n",
849
+ " <td>0.084600</td>\n",
850
+ " <td>3.600572</td>\n",
851
+ " <td>3.600573</td>\n",
852
+ " <td>1.897517</td>\n",
853
+ " <td>1.145824</td>\n",
854
+ " <td>0.014226</td>\n",
855
+ " <td>48.548377</td>\n",
856
+ " </tr>\n",
857
+ " </tbody>\n",
858
+ "</table><p>"
859
+ ],
860
+ "text/plain": [
861
+ "<IPython.core.display.HTML object>"
862
+ ]
863
+ },
864
+ "metadata": {},
865
+ "output_type": "display_data"
866
+ },
867
+ {
868
+ "name": "stderr",
869
+ "output_type": "stream",
870
+ "text": [
871
+ "[I 2025-01-09 12:30:33,965] Trial 3 finished with value: 1.8975174797770824 and parameters: {'learning_rate': 1.1750268648920993e-06}. Best is trial 2 with value: 1.6214995462309338.\n",
872
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
873
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
874
+ ]
875
+ },
876
+ {
877
+ "data": {
878
+ "text/html": [
879
+ "\n",
880
+ " <div>\n",
881
+ " \n",
882
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
883
+ " [620/620 03:27, Epoch 5/5]\n",
884
+ " </div>\n",
885
+ " <table border=\"1\" class=\"dataframe\">\n",
886
+ " <thead>\n",
887
+ " <tr style=\"text-align: left;\">\n",
888
+ " <th>Epoch</th>\n",
889
+ " <th>Training Loss</th>\n",
890
+ " <th>Validation Loss</th>\n",
891
+ " <th>Mse</th>\n",
892
+ " <th>Rmse</th>\n",
893
+ " <th>Mae</th>\n",
894
+ " <th>R2</th>\n",
895
+ " <th>Smape</th>\n",
896
+ " </tr>\n",
897
+ " </thead>\n",
898
+ " <tbody>\n",
899
+ " <tr>\n",
900
+ " <td>1</td>\n",
901
+ " <td>0.085600</td>\n",
902
+ " <td>3.761341</td>\n",
903
+ " <td>3.761341</td>\n",
904
+ " <td>1.939418</td>\n",
905
+ " <td>1.156432</td>\n",
906
+ " <td>-0.029790</td>\n",
907
+ " <td>46.601269</td>\n",
908
+ " </tr>\n",
909
+ " <tr>\n",
910
+ " <td>2</td>\n",
911
+ " <td>2.913400</td>\n",
912
+ " <td>3.756832</td>\n",
913
+ " <td>3.756831</td>\n",
914
+ " <td>1.938255</td>\n",
915
+ " <td>1.238454</td>\n",
916
+ " <td>-0.028555</td>\n",
917
+ " <td>49.874967</td>\n",
918
+ " </tr>\n",
919
+ " <tr>\n",
920
+ " <td>3</td>\n",
921
+ " <td>32.276600</td>\n",
922
+ " <td>3.654472</td>\n",
923
+ " <td>3.654473</td>\n",
924
+ " <td>1.911668</td>\n",
925
+ " <td>1.135091</td>\n",
926
+ " <td>-0.000531</td>\n",
927
+ " <td>48.732340</td>\n",
928
+ " </tr>\n",
929
+ " <tr>\n",
930
+ " <td>4</td>\n",
931
+ " <td>0.083000</td>\n",
932
+ " <td>3.665871</td>\n",
933
+ " <td>3.665871</td>\n",
934
+ " <td>1.914646</td>\n",
935
+ " <td>1.162767</td>\n",
936
+ " <td>-0.003652</td>\n",
937
+ " <td>49.439710</td>\n",
938
+ " </tr>\n",
939
+ " <tr>\n",
940
+ " <td>5</td>\n",
941
+ " <td>0.055800</td>\n",
942
+ " <td>3.610057</td>\n",
943
+ " <td>3.610057</td>\n",
944
+ " <td>1.900015</td>\n",
945
+ " <td>1.183222</td>\n",
946
+ " <td>0.011629</td>\n",
947
+ " <td>49.474382</td>\n",
948
+ " </tr>\n",
949
+ " </tbody>\n",
950
+ "</table><p>"
951
+ ],
952
+ "text/plain": [
953
+ "<IPython.core.display.HTML object>"
954
+ ]
955
+ },
956
+ "metadata": {},
957
+ "output_type": "display_data"
958
+ },
959
+ {
960
+ "name": "stderr",
961
+ "output_type": "stream",
962
+ "text": [
963
+ "[I 2025-01-09 12:34:05,271] Trial 4 finished with value: 1.9000149676084739 and parameters: {'learning_rate': 2.308984942228097e-06}. Best is trial 2 with value: 1.6214995462309338.\n"
964
+ ]
965
+ }
966
+ ],
967
+ "source": [
968
+ "best_trial = hp_trainer.hyperparameter_search(\n",
969
+ " direction=\"minimize\",\n",
970
+ " backend=\"optuna\",\n",
971
+ " hp_space=optuna_hp_space,\n",
972
+ " n_trials=5,\n",
973
+ " compute_objective=lambda x: x['eval_rmse'],\n",
974
+ ")"
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": 26,
980
+ "id": "c9c39fa6-3f84-4082-879a-7efd5d21e174",
981
+ "metadata": {},
982
+ "outputs": [
983
+ {
984
+ "data": {
985
+ "text/plain": [
986
+ "BestRun(run_id='2', objective=1.6214995462309338, hyperparameters={'learning_rate': 2.479942619764035e-05}, run_summary=None)"
987
+ ]
988
+ },
989
+ "execution_count": 26,
990
+ "metadata": {},
991
+ "output_type": "execute_result"
992
+ }
993
+ ],
994
+ "source": [
995
+ "best_trial"
996
+ ]
997
+ },
998
+ {
999
+ "cell_type": "markdown",
1000
+ "id": "f511f354-5c3e-4c62-8063-f769a6c1b9ca",
1001
+ "metadata": {},
1002
+ "source": [
1003
+ "### Fit and upload the best Model\n",
1004
+ "We re-fit the model with the best hyperparameters in accordaince with this [forum post](https://discuss.huggingface.co/t/how-to-save-the-best-trials-model-using-trainer-hyperparameter-search/8783/4)"
1005
+ ]
1006
+ },
1007
+ {
1008
+ "cell_type": "code",
1009
+ "execution_count": 27,
1010
+ "id": "ad4e4cf2-286e-4b4a-9c3b-2024be1769b8",
1011
+ "metadata": {},
1012
+ "outputs": [
1013
+ {
1014
+ "name": "stderr",
1015
+ "output_type": "stream",
1016
+ "text": [
1017
+ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
1018
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
1019
+ "/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
1020
+ " warnings.warn(\n"
1021
+ ]
1022
+ }
1023
+ ],
1024
+ "source": [
1025
+ "best_trainer = Trainer(\n",
1026
+ " model=model_init(None),\n",
1027
+ " args=gen_training_args({**best_trial.hyperparameters}),\n",
1028
+ " train_dataset=tokenized_dataset[\"train\"],\n",
1029
+ " eval_dataset=tokenized_dataset[\"test\"],\n",
1030
+ " compute_metrics=compute_metrics_for_regression,\n",
1031
+ ")"
1032
+ ]
1033
+ },
1034
+ {
1035
+ "cell_type": "code",
1036
+ "execution_count": 28,
1037
+ "id": "a031f566-4be1-440a-8481-f23e609ac3b3",
1038
+ "metadata": {},
1039
+ "outputs": [
1040
+ {
1041
+ "data": {
1042
+ "text/html": [
1043
+ "\n",
1044
+ " <div>\n",
1045
+ " \n",
1046
+ " <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1047
+ " [620/620 03:25, Epoch 5/5]\n",
1048
+ " </div>\n",
1049
+ " <table border=\"1\" class=\"dataframe\">\n",
1050
+ " <thead>\n",
1051
+ " <tr style=\"text-align: left;\">\n",
1052
+ " <th>Epoch</th>\n",
1053
+ " <th>Training Loss</th>\n",
1054
+ " <th>Validation Loss</th>\n",
1055
+ " <th>Mse</th>\n",
1056
+ " <th>Rmse</th>\n",
1057
+ " <th>Mae</th>\n",
1058
+ " <th>R2</th>\n",
1059
+ " <th>Smape</th>\n",
1060
+ " </tr>\n",
1061
+ " </thead>\n",
1062
+ " <tbody>\n",
1063
+ " <tr>\n",
1064
+ " <td>1</td>\n",
1065
+ " <td>0.115200</td>\n",
1066
+ " <td>4.084211</td>\n",
1067
+ " <td>4.084211</td>\n",
1068
+ " <td>2.020943</td>\n",
1069
+ " <td>1.219903</td>\n",
1070
+ " <td>-0.118186</td>\n",
1071
+ " <td>49.023473</td>\n",
1072
+ " </tr>\n",
1073
+ " <tr>\n",
1074
+ " <td>2</td>\n",
1075
+ " <td>1.239000</td>\n",
1076
+ " <td>3.803578</td>\n",
1077
+ " <td>3.803578</td>\n",
1078
+ " <td>1.950276</td>\n",
1079
+ " <td>1.289222</td>\n",
1080
+ " <td>-0.041354</td>\n",
1081
+ " <td>52.775413</td>\n",
1082
+ " </tr>\n",
1083
+ " <tr>\n",
1084
+ " <td>3</td>\n",
1085
+ " <td>27.825600</td>\n",
1086
+ " <td>3.245966</td>\n",
1087
+ " <td>3.245967</td>\n",
1088
+ " <td>1.801657</td>\n",
1089
+ " <td>1.102216</td>\n",
1090
+ " <td>0.111311</td>\n",
1091
+ " <td>51.747030</td>\n",
1092
+ " </tr>\n",
1093
+ " <tr>\n",
1094
+ " <td>4</td>\n",
1095
+ " <td>0.000100</td>\n",
1096
+ " <td>2.413429</td>\n",
1097
+ " <td>2.413429</td>\n",
1098
+ " <td>1.553521</td>\n",
1099
+ " <td>1.081085</td>\n",
1100
+ " <td>0.339245</td>\n",
1101
+ " <td>52.221513</td>\n",
1102
+ " </tr>\n",
1103
+ " <tr>\n",
1104
+ " <td>5</td>\n",
1105
+ " <td>0.166600</td>\n",
1106
+ " <td>2.462405</td>\n",
1107
+ " <td>2.462406</td>\n",
1108
+ " <td>1.569205</td>\n",
1109
+ " <td>1.182182</td>\n",
1110
+ " <td>0.325836</td>\n",
1111
+ " <td>56.614470</td>\n",
1112
+ " </tr>\n",
1113
+ " </tbody>\n",
1114
+ "</table><p>"
1115
+ ],
1116
+ "text/plain": [
1117
+ "<IPython.core.display.HTML object>"
1118
+ ]
1119
+ },
1120
+ "metadata": {},
1121
+ "output_type": "display_data"
1122
+ },
1123
+ {
1124
+ "data": {
1125
+ "text/plain": [
1126
+ "TrainOutput(global_step=620, training_loss=4.329616037725622, metrics={'train_runtime': 205.4329, 'train_samples_per_second': 11.999, 'train_steps_per_second': 3.018, 'total_flos': 3359849068769280.0, 'train_loss': 4.329616037725622, 'epoch': 5.0})"
1127
+ ]
1128
+ },
1129
+ "execution_count": 28,
1130
+ "metadata": {},
1131
+ "output_type": "execute_result"
1132
+ }
1133
+ ],
1134
+ "source": [
1135
+ "best_trainer.train() "
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "code",
1140
+ "execution_count": 29,
1141
+ "id": "09a0f8e2-7986-4171-902e-c08bcd5d6088",
1142
+ "metadata": {},
1143
+ "outputs": [
1144
+ {
1145
+ "data": {
1146
+ "text/html": [
1147
+ "\n",
1148
+ " <div>\n",
1149
+ " \n",
1150
+ " <progress value='14' max='14' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1151
+ " [14/14 00:01]\n",
1152
+ " </div>\n",
1153
+ " "
1154
+ ],
1155
+ "text/plain": [
1156
+ "<IPython.core.display.HTML object>"
1157
+ ]
1158
+ },
1159
+ "metadata": {},
1160
+ "output_type": "display_data"
1161
+ },
1162
+ {
1163
+ "data": {
1164
+ "text/plain": [
1165
+ "{'eval_loss': 2.4624054431915283,\n",
1166
+ " 'eval_mse': 2.4624056816101074,\n",
1167
+ " 'eval_rmse': 1.5692054300218654,\n",
1168
+ " 'eval_mae': 1.182181715965271,\n",
1169
+ " 'eval_r2': 0.325836181640625,\n",
1170
+ " 'eval_smape': 56.61447048187256,\n",
1171
+ " 'eval_runtime': 1.3489,\n",
1172
+ " 'eval_samples_per_second': 40.774,\n",
1173
+ " 'eval_steps_per_second': 10.379,\n",
1174
+ " 'epoch': 5.0}"
1175
+ ]
1176
+ },
1177
+ "execution_count": 29,
1178
+ "metadata": {},
1179
+ "output_type": "execute_result"
1180
+ }
1181
+ ],
1182
+ "source": [
1183
+ "best_trainer.evaluate()"
1184
+ ]
1185
+ },
1186
+ {
1187
+ "cell_type": "code",
1188
+ "execution_count": 30,
1189
+ "id": "6360b8dd-c456-4c0f-a79f-b1fb7a99ad19",
1190
+ "metadata": {},
1191
+ "outputs": [
1192
+ {
1193
+ "data": {
1194
+ "application/vnd.jupyter.widget-view+json": {
1195
+ "model_id": "49576b7f5f6b4ea781dd6198df4f33f7",
1196
+ "version_major": 2,
1197
+ "version_minor": 0
1198
+ },
1199
+ "text/plain": [
1200
+ "events.out.tfevents.1736455080.bazzite: 0%| | 0.00/40.0 [00:00<?, ?B/s]"
1201
+ ]
1202
+ },
1203
+ "metadata": {},
1204
+ "output_type": "display_data"
1205
+ },
1206
+ {
1207
+ "data": {
1208
+ "text/plain": [
1209
+ "CommitInfo(commit_url='https://huggingface.co/Forecast-ing/modernBERT-content-regression/commit/16f1dc87782b2735f8fef84a5b10807b6cbe5565', commit_message='End of training', commit_description='', oid='16f1dc87782b2735f8fef84a5b10807b6cbe5565', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Forecast-ing/modernBERT-content-regression', endpoint='https://huggingface.co', repo_type='model', repo_id='Forecast-ing/modernBERT-content-regression'), pr_revision=None, pr_num=None)"
1210
+ ]
1211
+ },
1212
+ "execution_count": 30,
1213
+ "metadata": {},
1214
+ "output_type": "execute_result"
1215
+ }
1216
+ ],
1217
+ "source": [
1218
+ "tokenizer.save_pretrained(\"modernBERT-content-regression\")\n",
1219
+ "best_trainer.create_model_card()\n",
1220
+ "best_trainer.push_to_hub()"
1221
+ ]
1222
+ },
1223
+ {
1224
+ "cell_type": "code",
1225
+ "execution_count": null,
1226
+ "id": "2c9971d5-5db6-48ac-be23-4131f75a961c",
1227
+ "metadata": {},
1228
+ "outputs": [],
1229
+ "source": []
1230
+ },
1231
+ {
1232
+ "cell_type": "code",
1233
+ "execution_count": null,
1234
+ "id": "083db2d4-2601-44ef-aa87-5f4a1d66f8e3",
1235
+ "metadata": {},
1236
+ "outputs": [],
1237
+ "source": []
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": null,
1242
+ "id": "9db9a38d-653d-4b71-8cc2-e0f65ae065bd",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": []
1246
+ }
1247
+ ],
1248
+ "metadata": {
1249
+ "kernelspec": {
1250
+ "display_name": "Python 3 (ipykernel)",
1251
+ "language": "python",
1252
+ "name": "python3"
1253
+ },
1254
+ "language_info": {
1255
+ "codemirror_mode": {
1256
+ "name": "ipython",
1257
+ "version": 3
1258
+ },
1259
+ "file_extension": ".py",
1260
+ "mimetype": "text/x-python",
1261
+ "name": "python",
1262
+ "nbconvert_exporter": "python",
1263
+ "pygments_lexer": "ipython3",
1264
+ "version": "3.12.8"
1265
+ }
1266
+ },
1267
+ "nbformat": 4,
1268
+ "nbformat_minor": 5
1269
+ }
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [project]
+ name = "modernbert-content-regression"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "accelerate>=1.2.1",
+     "catboost>=1.2.7",
+     "datasets>=3.2.0",
+     "evaluate>=0.4.3",
+     "hf-transfer>=0.1.8",
+     "optuna>=4.1.0",
+     "python-dotenv>=1.0.1",
+     "scikit-learn>=1.6.0",
+     "tensorboardx>=2.6.2.2",
+     "transformers",
+ ]
+
+ [tool.uv.sources]
+ tinytroupe = { git = "https://github.com/microsoft/TinyTroupe.git" }
+ transformers = { git = "https://github.com/huggingface/transformers.git", rev = "6e0515e99c39444caae39472ee1b2fd76ece32f1" }
+
+ [dependency-groups]
+ dev = [
+     "ruff>=0.8.4",
+ ]
uv.lock ADDED
The diff for this file is too large to render.