Robin Tully
commited on
Commit
·
9586c0c
1
Parent(s):
20d9458
training notebook
Browse files- .gitignore +167 -0
- .python-version +1 -0
- README.md +65 -1
- model_training.ipynb +1269 -0
- pyproject.toml +27 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
.ruff_cache
|
3 |
+
catboost_info/
|
4 |
+
modernBERT-content-regression/
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
115 |
+
.pdm.toml
|
116 |
+
.pdm-python
|
117 |
+
.pdm-build/
|
118 |
+
|
119 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
120 |
+
__pypackages__/
|
121 |
+
|
122 |
+
# Celery stuff
|
123 |
+
celerybeat-schedule
|
124 |
+
celerybeat.pid
|
125 |
+
|
126 |
+
# SageMath parsed files
|
127 |
+
*.sage.py
|
128 |
+
|
129 |
+
# Environments
|
130 |
+
.env
|
131 |
+
.venv
|
132 |
+
env/
|
133 |
+
venv/
|
134 |
+
ENV/
|
135 |
+
env.bak/
|
136 |
+
venv.bak/
|
137 |
+
|
138 |
+
# Spyder project settings
|
139 |
+
.spyderproject
|
140 |
+
.spyproject
|
141 |
+
|
142 |
+
# Rope project settings
|
143 |
+
.ropeproject
|
144 |
+
|
145 |
+
# mkdocs documentation
|
146 |
+
/site
|
147 |
+
|
148 |
+
# mypy
|
149 |
+
.mypy_cache/
|
150 |
+
.dmypy.json
|
151 |
+
dmypy.json
|
152 |
+
|
153 |
+
# Pyre type checker
|
154 |
+
.pyre/
|
155 |
+
|
156 |
+
# pytype static type analyzer
|
157 |
+
.pytype/
|
158 |
+
|
159 |
+
# Cython debug symbols
|
160 |
+
cython_debug/
|
161 |
+
|
162 |
+
# PyCharm
|
163 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
164 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
165 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
166 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
167 |
+
#.idea/
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.12
|
README.md
CHANGED
@@ -1 +1,65 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
### What is this?
|
3 |
+
|
4 |
+
This is an exploration of using modernBERT for the text regression task of predicting engagement metrics for text content. In this case, we are predicting the clickthrough rate (CTR) of email text content.
|
5 |
+
|
6 |
+
We will be exploring hyperparameter tuning of modernBert; and how to use it for regression, as well as comparing the results to a benchmark model.
|
7 |
+
|
8 |
+
This type of task is difficult; we can remember the quote
|
9 |
+
> “Half my advertising is wasted; the trouble is, I don't know which half”
|
10 |
+
> -John Wanamaker
|
11 |
+
|
12 |
+
We are also excluding other relevant factors such as the time of day the email is sent, the day of the week, the recipient, etc in this experiment.
|
13 |
+
|
14 |
+
This work is indebted to the work of many community members and blog posts.
|
15 |
+
- [ModernBERT Announcement](https://huggingface.co/blog/modernbert)
|
16 |
+
- [Fine-tune classifier with ModernBERT in 2025](https://www.philschmid.de/fine-tune-modern-bert-in-2025)
|
17 |
+
- [How to set up Trainer for a regression](https://discuss.huggingface.co/t/how-to-set-up-trainer-for-a-regression/12994)
|
18 |
+
|
19 |
+
### Our dataset
|
20 |
+
We will be using a dataset of 548 emails where we have the text of the email `text` and the CTR we are trying to predict `labels`.
|
21 |
+
|
22 |
+
We look forward to improvements in ModernBERT for fine-tuning models specifically on each potential user's email dataset. The variability of email data, as well as the small size of the dataset, pose an interesting regression challenge.
|
23 |
+
|
24 |
+
### Benchmarking
|
25 |
+
We will start by using the Catboost library as a simple benchmark for text regression. For both the benchmark and the ModernBert run, we are using 'rmse' as the metric.
|
26 |
+
We receive the following results:
|
27 |
+
| Metric | Value |
|
28 |
+
|--------|------------------|
|
29 |
+
| MSE | 2.552100633998035 |
|
30 |
+
| RMSE | 1.5975295408843102 |
|
31 |
+
| MAE | 1.1439370629666958 |
|
32 |
+
| R² | 0.30127932054387174 |
|
33 |
+
| SMAPE | 37.63064694052479 |
|
34 |
+
|
35 |
+
## Fitting the Modern Bert Model
|
36 |
+
|
37 |
+
### Install dependencies and activate venv
|
38 |
+
```bash
|
39 |
+
uv sync
|
40 |
+
source .venv/bin/activate
|
41 |
+
```
|
42 |
+
the following values need to be defined in the .env file
|
43 |
+
- `HUGGINGFACE_API_KEY` (read by the notebook via `os.getenv("HUGGINGFACE_API_KEY")`)
|
44 |
+
|
45 |
+
### Run notebook for model fitting
|
46 |
+
|
47 |
+
```bash
|
48 |
+
uv run --with jupyter jupyter lab
|
49 |
+
```
|
50 |
+
|
51 |
+
### ModernBert Model Performance
|
52 |
+
After running hyperparameter tuning for ModernBERT, we get the following results:
|
53 |
+
|
54 |
+
| Metric | Value |
|
55 |
+
|--------|------------------|
|
56 |
+
| MSE | 2.4624056816101074 |
|
57 |
+
| RMSE | 1.5692054300218654 |
|
58 |
+
| MAE | 1.182181715965271 |
|
59 |
+
| R² | 0.325836181640625 |
|
60 |
+
| SMAPE | 56.61447048187256 |
|
61 |
+
|
62 |
+
We see improvements in all metrics except for SMAPE. We believe that ModernBERT would scale even better with a larger dataset, as ~500 examples is very low for fine-tuning, and we are thus happy with the performance of this evaluation.
|
63 |
+
|
64 |
+
## Conclusion
|
65 |
+
We see that ModernBERT is a powerful model for text regression. We believe that with a larger dataset, we would see even better results. We are excited to see the future of ModernBERT and how it will be used for text regression.
|
model_training.ipynb
ADDED
@@ -0,0 +1,1269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "51a9bb64-969a-4fc0-aa76-4bd42b08c21a",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"data": {
|
11 |
+
"text/plain": [
|
12 |
+
"True"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
"execution_count": 1,
|
16 |
+
"metadata": {},
|
17 |
+
"output_type": "execute_result"
|
18 |
+
}
|
19 |
+
],
|
20 |
+
"source": [
|
21 |
+
"import os\n",
|
22 |
+
"\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"import pandas as pd\n",
|
25 |
+
"from catboost import CatBoostRegressor, Pool\n",
|
26 |
+
"from datasets import load_dataset\n",
|
27 |
+
"from dotenv import load_dotenv\n",
|
28 |
+
"from huggingface_hub import HfFolder, login\n",
|
29 |
+
"from sklearn.metrics import (\n",
|
30 |
+
" mean_absolute_error,\n",
|
31 |
+
" mean_squared_error,\n",
|
32 |
+
" r2_score,\n",
|
33 |
+
" root_mean_squared_error,\n",
|
34 |
+
")\n",
|
35 |
+
"from transformers import (\n",
|
36 |
+
" AutoModelForSequenceClassification,\n",
|
37 |
+
" AutoTokenizer,\n",
|
38 |
+
" Trainer,\n",
|
39 |
+
" TrainingArguments,\n",
|
40 |
+
")\n",
|
41 |
+
"\n",
|
42 |
+
"load_dotenv()"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 2,
|
48 |
+
"id": "9b53c782-5dbb-4dd6-b541-8e4fab3f3ddf",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"login(token=os.getenv(\"HUGGINGFACE_API_KEY\"))"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"id": "87f58ec1-231d-4c17-8af2-c366af55e375",
|
58 |
+
"metadata": {},
|
59 |
+
"source": [
|
60 |
+
"### Dataset prep"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 3,
|
66 |
+
"id": "0ea777d8-988b-421a-8d76-b0f9256ab61b",
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"raw_dataset = load_dataset(\"Forecast-ing/email-clickthrough\")"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 4,
|
76 |
+
"id": "0c8441e7-1606-4c61-8f6c-34ebe1e107c0",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [],
|
79 |
+
"source": [
|
80 |
+
"raw_dataset = raw_dataset.rename_column(\"label\", \"labels\")"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 5,
|
86 |
+
"id": "201b806d-6c94-4053-98b6-4d22dcdda08a",
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [
|
89 |
+
{
|
90 |
+
"data": {
|
91 |
+
"text/plain": [
|
92 |
+
"3292"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
"execution_count": 5,
|
96 |
+
"metadata": {},
|
97 |
+
"output_type": "execute_result"
|
98 |
+
}
|
99 |
+
],
|
100 |
+
"source": [
|
101 |
+
"raw_dataset[\"train\"].to_pandas()[\"text\"].str.len().max()"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 6,
|
107 |
+
"id": "cbd4d6ec-b293-49f1-925f-239549dab61e",
|
108 |
+
"metadata": {},
|
109 |
+
"outputs": [
|
110 |
+
{
|
111 |
+
"data": {
|
112 |
+
"text/plain": [
|
113 |
+
"0.2427007299270073"
|
114 |
+
]
|
115 |
+
},
|
116 |
+
"execution_count": 6,
|
117 |
+
"metadata": {},
|
118 |
+
"output_type": "execute_result"
|
119 |
+
}
|
120 |
+
],
|
121 |
+
"source": [
|
122 |
+
"(raw_dataset[\"train\"].to_pandas()[\"text\"].str.len() > 2048).mean()"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": 7,
|
128 |
+
"id": "1101c038-3f83-4055-b938-3861ac43cf8f",
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [
|
131 |
+
{
|
132 |
+
"data": {
|
133 |
+
"text/plain": [
|
134 |
+
"count 548.000000\n",
|
135 |
+
"mean 2.879635\n",
|
136 |
+
"std 2.423870\n",
|
137 |
+
"min 0.450000\n",
|
138 |
+
"25% 1.510000\n",
|
139 |
+
"50% 2.025000\n",
|
140 |
+
"75% 3.267500\n",
|
141 |
+
"max 25.370000\n",
|
142 |
+
"Name: labels, dtype: float64"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
"execution_count": 7,
|
146 |
+
"metadata": {},
|
147 |
+
"output_type": "execute_result"
|
148 |
+
}
|
149 |
+
],
|
150 |
+
"source": [
|
151 |
+
"raw_dataset[\"train\"].to_pandas()[\"labels\"].describe()"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 8,
|
157 |
+
"id": "7546577f-4d7f-41b5-a68f-095fc0e8eec4",
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"raw_dataset = raw_dataset[\"train\"].train_test_split(test_size=0.1, seed=1)"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "code",
|
166 |
+
"execution_count": 9,
|
167 |
+
"id": "a0c28ab6-20f8-47dc-8ee9-179ea15830e0",
|
168 |
+
"metadata": {},
|
169 |
+
"outputs": [
|
170 |
+
{
|
171 |
+
"name": "stdout",
|
172 |
+
"output_type": "stream",
|
173 |
+
"text": [
|
174 |
+
"Train dataset size: 493\n",
|
175 |
+
"Test dataset size: 55\n"
|
176 |
+
]
|
177 |
+
}
|
178 |
+
],
|
179 |
+
"source": [
|
180 |
+
"print(f\"Train dataset size: {len(raw_dataset['train'])}\")\n",
|
181 |
+
"print(f\"Test dataset size: {len(raw_dataset['test'])}\")"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "markdown",
|
186 |
+
"id": "a2c3e7c3-e31d-4d8f-8d7d-e62050a9ae9d",
|
187 |
+
"metadata": {},
|
188 |
+
"source": [
|
189 |
+
"### Catboost Benchmark"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 10,
|
195 |
+
"id": "01aaea26-e1df-4493-b9cd-732c3b7a76a9",
|
196 |
+
"metadata": {},
|
197 |
+
"outputs": [],
|
198 |
+
"source": [
|
199 |
+
"catboost_train = raw_dataset[\"train\"].to_pandas()\n",
|
200 |
+
"catboost_test = raw_dataset[\"test\"].to_pandas()"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 11,
|
206 |
+
"id": "0243e07d-69ba-41e5-b54d-d0f4988bbf9f",
|
207 |
+
"metadata": {},
|
208 |
+
"outputs": [],
|
209 |
+
"source": [
|
210 |
+
"text_columns = [\"text\"]\n",
|
211 |
+
"label = \"labels\""
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 12,
|
217 |
+
"id": "ba17040c-5882-47dc-a8af-a7557356840f",
|
218 |
+
"metadata": {},
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"train_pool = Pool(\n",
|
222 |
+
" data=catboost_train[text_columns],\n",
|
223 |
+
" label=catboost_train[label],\n",
|
224 |
+
" text_features=text_columns,\n",
|
225 |
+
")\n",
|
226 |
+
"test_pool = Pool(\n",
|
227 |
+
" data=catboost_test[text_columns],\n",
|
228 |
+
" label=catboost_test[label],\n",
|
229 |
+
" text_features=text_columns,\n",
|
230 |
+
")"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": 13,
|
236 |
+
"id": "d8b3768e-6f30-41ce-a209-bd915a997d8a",
|
237 |
+
"metadata": {},
|
238 |
+
"outputs": [
|
239 |
+
{
|
240 |
+
"name": "stdout",
|
241 |
+
"output_type": "stream",
|
242 |
+
"text": [
|
243 |
+
"Learning rate set to 0.045569\n",
|
244 |
+
"0:\tlearn: 2.4332854\ttest: 1.8670741\tbest: 1.8670741 (0)\ttotal: 60.5ms\tremaining: 1m\n",
|
245 |
+
"100:\tlearn: 1.4972558\ttest: 1.6247590\tbest: 1.6048404 (59)\ttotal: 2.5s\tremaining: 22.2s\n",
|
246 |
+
"200:\tlearn: 1.1104040\ttest: 1.6015944\tbest: 1.5975296 (197)\ttotal: 4.91s\tremaining: 19.5s\n",
|
247 |
+
"300:\tlearn: 0.8568033\ttest: 1.6102309\tbest: 1.5975296 (197)\ttotal: 7.33s\tremaining: 17s\n",
|
248 |
+
"400:\tlearn: 0.7096792\ttest: 1.6090190\tbest: 1.5975296 (197)\ttotal: 9.72s\tremaining: 14.5s\n",
|
249 |
+
"500:\tlearn: 0.6056532\ttest: 1.6083240\tbest: 1.5975296 (197)\ttotal: 12.1s\tremaining: 12s\n",
|
250 |
+
"600:\tlearn: 0.5298016\ttest: 1.6175366\tbest: 1.5975296 (197)\ttotal: 14.5s\tremaining: 9.64s\n",
|
251 |
+
"700:\tlearn: 0.4701467\ttest: 1.6262668\tbest: 1.5975296 (197)\ttotal: 16.9s\tremaining: 7.23s\n",
|
252 |
+
"800:\tlearn: 0.4233732\ttest: 1.6199203\tbest: 1.5975296 (197)\ttotal: 19.4s\tremaining: 4.81s\n",
|
253 |
+
"900:\tlearn: 0.3837074\ttest: 1.6104091\tbest: 1.5975296 (197)\ttotal: 21.8s\tremaining: 2.39s\n",
|
254 |
+
"999:\tlearn: 0.3501113\ttest: 1.6131207\tbest: 1.5975296 (197)\ttotal: 24.2s\tremaining: 0us\n",
|
255 |
+
"\n",
|
256 |
+
"bestTest = 1.597529566\n",
|
257 |
+
"bestIteration = 197\n",
|
258 |
+
"\n",
|
259 |
+
"Shrink model to first 198 iterations.\n"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"data": {
|
264 |
+
"text/plain": [
|
265 |
+
"<catboost.core.CatBoostRegressor at 0x7fb1061c5bb0>"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
"execution_count": 13,
|
269 |
+
"metadata": {},
|
270 |
+
"output_type": "execute_result"
|
271 |
+
}
|
272 |
+
],
|
273 |
+
"source": [
|
274 |
+
"model = CatBoostRegressor(loss_function=\"RMSE\", verbose=100)\n",
|
275 |
+
"\n",
|
276 |
+
"model.fit(train_pool, eval_set=test_pool)"
|
277 |
+
]
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"cell_type": "code",
|
281 |
+
"execution_count": 14,
|
282 |
+
"id": "837b22a8-241d-49b3-a1ae-915893121319",
|
283 |
+
"metadata": {},
|
284 |
+
"outputs": [],
|
285 |
+
"source": [
|
286 |
+
"y_pred = model.predict(test_pool)\n",
|
287 |
+
"y_val = catboost_test[label]"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"cell_type": "code",
|
292 |
+
"execution_count": 15,
|
293 |
+
"id": "478521bf-85be-49a1-8461-020587f146d2",
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"def smape(y_true, y_pred):\n",
|
298 |
+
" return 100 * np.mean(\n",
|
299 |
+
" 2 * np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred))\n",
|
300 |
+
" )\n",
|
301 |
+
"\n",
|
302 |
+
"\n",
|
303 |
+
"def calculate_metrics(y_val, y_pred):\n",
|
304 |
+
" mse = mean_squared_error(y_val, y_pred)\n",
|
305 |
+
" rmse = np.sqrt(mse)\n",
|
306 |
+
" mae = mean_absolute_error(y_val, y_pred)\n",
|
307 |
+
" r2 = r2_score(y_val, y_pred)\n",
|
308 |
+
" smape_value = smape(y_val, y_pred)\n",
|
309 |
+
" return {\n",
|
310 |
+
" \"mse\": mse,\n",
|
311 |
+
" \"rmse\": rmse,\n",
|
312 |
+
" \"mae\": mae,\n",
|
313 |
+
" \"r2\": r2,\n",
|
314 |
+
" \"smape\": smape_value,\n",
|
315 |
+
" }"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"cell_type": "code",
|
320 |
+
"execution_count": 16,
|
321 |
+
"id": "91ea95e2-1818-45a6-8725-3a1353cb5b97",
|
322 |
+
"metadata": {},
|
323 |
+
"outputs": [],
|
324 |
+
"source": [
|
325 |
+
"catboost_metrics = calculate_metrics(y_val, y_pred)"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "code",
|
330 |
+
"execution_count": 17,
|
331 |
+
"id": "e28e359e-f69c-4ee8-9bbd-e7afafafcd26",
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [
|
334 |
+
{
|
335 |
+
"data": {
|
336 |
+
"text/plain": [
|
337 |
+
"{'mse': 2.552100633998035,\n",
|
338 |
+
" 'rmse': 1.5975295408843102,\n",
|
339 |
+
" 'mae': 1.1439370629666958,\n",
|
340 |
+
" 'r2': 0.30127932054387174,\n",
|
341 |
+
" 'smape': 37.63064694052479}"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
"execution_count": 17,
|
345 |
+
"metadata": {},
|
346 |
+
"output_type": "execute_result"
|
347 |
+
}
|
348 |
+
],
|
349 |
+
"source": [
|
350 |
+
"catboost_metrics"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "markdown",
|
355 |
+
"id": "7afac97a-1e69-47e4-9ecd-ece7ebe7b48f",
|
356 |
+
"metadata": {},
|
357 |
+
"source": [
|
358 |
+
"### Fine Tuning Modern Bert"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "code",
|
363 |
+
"execution_count": 18,
|
364 |
+
"id": "031df047-2c18-4ec9-a498-596a7cf965b7",
|
365 |
+
"metadata": {},
|
366 |
+
"outputs": [],
|
367 |
+
"source": [
|
368 |
+
"model_id = \"answerdotai/ModernBERT-base\"\n",
|
369 |
+
"\n",
|
370 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
371 |
+
"tokenizer.model_max_length = 2048\n",
|
372 |
+
"\n",
|
373 |
+
"def tokenize(batch):\n",
|
374 |
+
" return tokenizer(\n",
|
375 |
+
" batch[\"text\"], padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
|
376 |
+
" )"
|
377 |
+
]
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "code",
|
381 |
+
"execution_count": 19,
|
382 |
+
"id": "bbb711e8-8da6-401b-a2bd-c1637731f6c9",
|
383 |
+
"metadata": {},
|
384 |
+
"outputs": [],
|
385 |
+
"source": [
|
386 |
+
"tokenized_dataset = raw_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])"
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": 20,
|
392 |
+
"id": "ca4c2942-bf9e-4579-82b9-654fbee85b54",
|
393 |
+
"metadata": {},
|
394 |
+
"outputs": [],
|
395 |
+
"source": [
|
396 |
+
"def model_init(trial):\n",
|
397 |
+
" model = AutoModelForSequenceClassification.from_pretrained(\n",
|
398 |
+
" model_id, num_labels=1, ignore_mismatched_sizes=True, problem_type=\"regression\"\n",
|
399 |
+
" )\n",
|
400 |
+
" return model"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": 21,
|
406 |
+
"id": "a2005499-e139-4151-9ebe-2759710149b1",
|
407 |
+
"metadata": {},
|
408 |
+
"outputs": [],
|
409 |
+
"source": [
|
410 |
+
"def gen_training_args(additional_args={}):\n",
|
411 |
+
" default_args = {\n",
|
412 |
+
" \"output_dir\": \"./modernBERT-content-regression\",\n",
|
413 |
+
" \"per_device_eval_batch_size\": 4,\n",
|
414 |
+
" \"per_device_train_batch_size\": 4,\n",
|
415 |
+
" \"num_train_epochs\": 5,\n",
|
416 |
+
" \"bf16\": True, # bfloat16 training\n",
|
417 |
+
" \"optim\": \"adamw_torch_fused\", # improved optimizer\n",
|
418 |
+
" \"logging_strategy\": \"steps\",\n",
|
419 |
+
" \"logging_steps\": 1,\n",
|
420 |
+
" \"evaluation_strategy\": \"epoch\",\n",
|
421 |
+
" \"save_strategy\": \"epoch\",\n",
|
422 |
+
" \"save_total_limit\": 1,\n",
|
423 |
+
" \"metric_for_best_model\": \"rmse\",\n",
|
424 |
+
" \"greater_is_better\": False,\n",
|
425 |
+
" \"report_to\": \"tensorboard\",\n",
|
426 |
+
" \"push_to_hub\": True,\n",
|
427 |
+
" \"hub_private_repo\": True,\n",
|
428 |
+
" \"hub_strategy\": \"every_save\",\n",
|
429 |
+
" \"hub_token\": HfFolder.get_token(),\n",
|
430 |
+
" }\n",
|
431 |
+
" training_args = TrainingArguments(**default_args, **additional_args)\n",
|
432 |
+
" return training_args"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": 22,
|
438 |
+
"id": "e7e0ee17-9a1c-4789-8a4b-d2469a88837b",
|
439 |
+
"metadata": {},
|
440 |
+
"outputs": [],
|
441 |
+
"source": [
|
442 |
+
"def compute_metrics_for_regression(eval_pred):\n",
|
443 |
+
" predictions, labels = eval_pred\n",
|
444 |
+
" predictions = predictions.reshape(-1, 1)\n",
|
445 |
+
" results = calculate_metrics(labels, predictions)\n",
|
446 |
+
" return results\n"
|
447 |
+
]
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"cell_type": "code",
|
451 |
+
"execution_count": 23,
|
452 |
+
"id": "10981164-fffc-4128-89ae-41c9238074cd",
|
453 |
+
"metadata": {},
|
454 |
+
"outputs": [
|
455 |
+
{
|
456 |
+
"name": "stderr",
|
457 |
+
"output_type": "stream",
|
458 |
+
"text": [
|
459 |
+
"/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
|
460 |
+
" warnings.warn(\n",
|
461 |
+
"/tmp/ipykernel_22314/2727960756.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
|
462 |
+
" hp_trainer = Trainer(\n",
|
463 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
464 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
465 |
+
]
|
466 |
+
}
|
467 |
+
],
|
468 |
+
"source": [
|
469 |
+
"hp_trainer = Trainer(\n",
|
470 |
+
" model=None,\n",
|
471 |
+
" args=gen_training_args(),\n",
|
472 |
+
" train_dataset=tokenized_dataset[\"train\"],\n",
|
473 |
+
" eval_dataset=tokenized_dataset[\"test\"],\n",
|
474 |
+
" tokenizer=tokenizer,\n",
|
475 |
+
" compute_metrics=compute_metrics_for_regression,\n",
|
476 |
+
" model_init=model_init,\n",
|
477 |
+
")"
|
478 |
+
]
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"cell_type": "code",
|
482 |
+
"execution_count": 24,
|
483 |
+
"id": "261a4988-e5f9-469f-b8ae-ef51ae6a95df",
|
484 |
+
"metadata": {},
|
485 |
+
"outputs": [],
|
486 |
+
"source": [
|
487 |
+
"def optuna_hp_space(trial):\n",
|
488 |
+
" return {\n",
|
489 |
+
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 5e-7, 5e-5, log=True),\n",
|
490 |
+
" }"
|
491 |
+
]
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "code",
|
495 |
+
"execution_count": 25,
|
496 |
+
"id": "7e7fb4f6-4eb4-4a4c-931d-a97c1614a5b9",
|
497 |
+
"metadata": {},
|
498 |
+
"outputs": [
|
499 |
+
{
|
500 |
+
"name": "stderr",
|
501 |
+
"output_type": "stream",
|
502 |
+
"text": [
|
503 |
+
"[I 2025-01-09 12:16:25,726] A new study created in memory with name: no-name-2f3f9073-d130-4bb1-9447-7262f2b7bd75\n",
|
504 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
505 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"data": {
|
510 |
+
"text/html": [
|
511 |
+
"\n",
|
512 |
+
" <div>\n",
|
513 |
+
" \n",
|
514 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
515 |
+
" [620/620 03:27, Epoch 5/5]\n",
|
516 |
+
" </div>\n",
|
517 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
518 |
+
" <thead>\n",
|
519 |
+
" <tr style=\"text-align: left;\">\n",
|
520 |
+
" <th>Epoch</th>\n",
|
521 |
+
" <th>Training Loss</th>\n",
|
522 |
+
" <th>Validation Loss</th>\n",
|
523 |
+
" <th>Mse</th>\n",
|
524 |
+
" <th>Rmse</th>\n",
|
525 |
+
" <th>Mae</th>\n",
|
526 |
+
" <th>R2</th>\n",
|
527 |
+
" <th>Smape</th>\n",
|
528 |
+
" </tr>\n",
|
529 |
+
" </thead>\n",
|
530 |
+
" <tbody>\n",
|
531 |
+
" <tr>\n",
|
532 |
+
" <td>1</td>\n",
|
533 |
+
" <td>0.238000</td>\n",
|
534 |
+
" <td>4.573008</td>\n",
|
535 |
+
" <td>4.573008</td>\n",
|
536 |
+
" <td>2.138459</td>\n",
|
537 |
+
" <td>1.324540</td>\n",
|
538 |
+
" <td>-0.252010</td>\n",
|
539 |
+
" <td>54.242009</td>\n",
|
540 |
+
" </tr>\n",
|
541 |
+
" <tr>\n",
|
542 |
+
" <td>2</td>\n",
|
543 |
+
" <td>3.768500</td>\n",
|
544 |
+
" <td>4.093452</td>\n",
|
545 |
+
" <td>4.093452</td>\n",
|
546 |
+
" <td>2.023228</td>\n",
|
547 |
+
" <td>1.458057</td>\n",
|
548 |
+
" <td>-0.120716</td>\n",
|
549 |
+
" <td>53.770840</td>\n",
|
550 |
+
" </tr>\n",
|
551 |
+
" <tr>\n",
|
552 |
+
" <td>3</td>\n",
|
553 |
+
" <td>27.661000</td>\n",
|
554 |
+
" <td>3.361875</td>\n",
|
555 |
+
" <td>3.361874</td>\n",
|
556 |
+
" <td>1.833541</td>\n",
|
557 |
+
" <td>1.126670</td>\n",
|
558 |
+
" <td>0.079577</td>\n",
|
559 |
+
" <td>52.641284</td>\n",
|
560 |
+
" </tr>\n",
|
561 |
+
" <tr>\n",
|
562 |
+
" <td>4</td>\n",
|
563 |
+
" <td>0.092300</td>\n",
|
564 |
+
" <td>2.759459</td>\n",
|
565 |
+
" <td>2.759459</td>\n",
|
566 |
+
" <td>1.661162</td>\n",
|
567 |
+
" <td>1.040074</td>\n",
|
568 |
+
" <td>0.244508</td>\n",
|
569 |
+
" <td>53.009331</td>\n",
|
570 |
+
" </tr>\n",
|
571 |
+
" <tr>\n",
|
572 |
+
" <td>5</td>\n",
|
573 |
+
" <td>0.020300</td>\n",
|
574 |
+
" <td>2.733250</td>\n",
|
575 |
+
" <td>2.733250</td>\n",
|
576 |
+
" <td>1.653254</td>\n",
|
577 |
+
" <td>1.078653</td>\n",
|
578 |
+
" <td>0.251684</td>\n",
|
579 |
+
" <td>54.187167</td>\n",
|
580 |
+
" </tr>\n",
|
581 |
+
" </tbody>\n",
|
582 |
+
"</table><p>"
|
583 |
+
],
|
584 |
+
"text/plain": [
|
585 |
+
"<IPython.core.display.HTML object>"
|
586 |
+
]
|
587 |
+
},
|
588 |
+
"metadata": {},
|
589 |
+
"output_type": "display_data"
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"name": "stderr",
|
593 |
+
"output_type": "stream",
|
594 |
+
"text": [
|
595 |
+
"[I 2025-01-09 12:19:57,000] Trial 0 finished with value: 1.6532543369745685 and parameters: {'learning_rate': 1.9437267223645173e-05}. Best is trial 0 with value: 1.6532543369745685.\n",
|
596 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
597 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
598 |
+
]
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"data": {
|
602 |
+
"text/html": [
|
603 |
+
"\n",
|
604 |
+
" <div>\n",
|
605 |
+
" \n",
|
606 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
607 |
+
" [620/620 03:30, Epoch 5/5]\n",
|
608 |
+
" </div>\n",
|
609 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
610 |
+
" <thead>\n",
|
611 |
+
" <tr style=\"text-align: left;\">\n",
|
612 |
+
" <th>Epoch</th>\n",
|
613 |
+
" <th>Training Loss</th>\n",
|
614 |
+
" <th>Validation Loss</th>\n",
|
615 |
+
" <th>Mse</th>\n",
|
616 |
+
" <th>Rmse</th>\n",
|
617 |
+
" <th>Mae</th>\n",
|
618 |
+
" <th>R2</th>\n",
|
619 |
+
" <th>Smape</th>\n",
|
620 |
+
" </tr>\n",
|
621 |
+
" </thead>\n",
|
622 |
+
" <tbody>\n",
|
623 |
+
" <tr>\n",
|
624 |
+
" <td>1</td>\n",
|
625 |
+
" <td>0.033500</td>\n",
|
626 |
+
" <td>3.730757</td>\n",
|
627 |
+
" <td>3.730757</td>\n",
|
628 |
+
" <td>1.931517</td>\n",
|
629 |
+
" <td>1.167591</td>\n",
|
630 |
+
" <td>-0.021416</td>\n",
|
631 |
+
" <td>46.438679</td>\n",
|
632 |
+
" </tr>\n",
|
633 |
+
" <tr>\n",
|
634 |
+
" <td>2</td>\n",
|
635 |
+
" <td>3.021100</td>\n",
|
636 |
+
" <td>3.532418</td>\n",
|
637 |
+
" <td>3.532420</td>\n",
|
638 |
+
" <td>1.879473</td>\n",
|
639 |
+
" <td>1.171051</td>\n",
|
640 |
+
" <td>0.032885</td>\n",
|
641 |
+
" <td>48.273236</td>\n",
|
642 |
+
" </tr>\n",
|
643 |
+
" <tr>\n",
|
644 |
+
" <td>3</td>\n",
|
645 |
+
" <td>32.454400</td>\n",
|
646 |
+
" <td>3.670944</td>\n",
|
647 |
+
" <td>3.670944</td>\n",
|
648 |
+
" <td>1.915971</td>\n",
|
649 |
+
" <td>1.159171</td>\n",
|
650 |
+
" <td>-0.005041</td>\n",
|
651 |
+
" <td>48.529482</td>\n",
|
652 |
+
" </tr>\n",
|
653 |
+
" <tr>\n",
|
654 |
+
" <td>4</td>\n",
|
655 |
+
" <td>0.074300</td>\n",
|
656 |
+
" <td>3.690546</td>\n",
|
657 |
+
" <td>3.690546</td>\n",
|
658 |
+
" <td>1.921079</td>\n",
|
659 |
+
" <td>1.179955</td>\n",
|
660 |
+
" <td>-0.010407</td>\n",
|
661 |
+
" <td>49.107727</td>\n",
|
662 |
+
" </tr>\n",
|
663 |
+
" <tr>\n",
|
664 |
+
" <td>5</td>\n",
|
665 |
+
" <td>0.098800</td>\n",
|
666 |
+
" <td>3.677439</td>\n",
|
667 |
+
" <td>3.677439</td>\n",
|
668 |
+
" <td>1.917665</td>\n",
|
669 |
+
" <td>1.188619</td>\n",
|
670 |
+
" <td>-0.006819</td>\n",
|
671 |
+
" <td>49.251461</td>\n",
|
672 |
+
" </tr>\n",
|
673 |
+
" </tbody>\n",
|
674 |
+
"</table><p>"
|
675 |
+
],
|
676 |
+
"text/plain": [
|
677 |
+
"<IPython.core.display.HTML object>"
|
678 |
+
]
|
679 |
+
},
|
680 |
+
"metadata": {},
|
681 |
+
"output_type": "display_data"
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"name": "stderr",
|
685 |
+
"output_type": "stream",
|
686 |
+
"text": [
|
687 |
+
"[I 2025-01-09 12:23:31,566] Trial 1 finished with value: 1.91766510403085 and parameters: {'learning_rate': 1.5810058165067856e-06}. Best is trial 0 with value: 1.6532543369745685.\n",
|
688 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
689 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
690 |
+
]
|
691 |
+
},
|
692 |
+
{
|
693 |
+
"data": {
|
694 |
+
"text/html": [
|
695 |
+
"\n",
|
696 |
+
" <div>\n",
|
697 |
+
" \n",
|
698 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
699 |
+
" [620/620 03:28, Epoch 5/5]\n",
|
700 |
+
" </div>\n",
|
701 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
702 |
+
" <thead>\n",
|
703 |
+
" <tr style=\"text-align: left;\">\n",
|
704 |
+
" <th>Epoch</th>\n",
|
705 |
+
" <th>Training Loss</th>\n",
|
706 |
+
" <th>Validation Loss</th>\n",
|
707 |
+
" <th>Mse</th>\n",
|
708 |
+
" <th>Rmse</th>\n",
|
709 |
+
" <th>Mae</th>\n",
|
710 |
+
" <th>R2</th>\n",
|
711 |
+
" <th>Smape</th>\n",
|
712 |
+
" </tr>\n",
|
713 |
+
" </thead>\n",
|
714 |
+
" <tbody>\n",
|
715 |
+
" <tr>\n",
|
716 |
+
" <td>1</td>\n",
|
717 |
+
" <td>0.311500</td>\n",
|
718 |
+
" <td>4.090590</td>\n",
|
719 |
+
" <td>4.090590</td>\n",
|
720 |
+
" <td>2.022521</td>\n",
|
721 |
+
" <td>1.229977</td>\n",
|
722 |
+
" <td>-0.119932</td>\n",
|
723 |
+
" <td>50.514507</td>\n",
|
724 |
+
" </tr>\n",
|
725 |
+
" <tr>\n",
|
726 |
+
" <td>2</td>\n",
|
727 |
+
" <td>2.652800</td>\n",
|
728 |
+
" <td>4.852318</td>\n",
|
729 |
+
" <td>4.852319</td>\n",
|
730 |
+
" <td>2.202798</td>\n",
|
731 |
+
" <td>1.465739</td>\n",
|
732 |
+
" <td>-0.328480</td>\n",
|
733 |
+
" <td>54.715651</td>\n",
|
734 |
+
" </tr>\n",
|
735 |
+
" <tr>\n",
|
736 |
+
" <td>3</td>\n",
|
737 |
+
" <td>24.626400</td>\n",
|
738 |
+
" <td>3.331610</td>\n",
|
739 |
+
" <td>3.331610</td>\n",
|
740 |
+
" <td>1.825270</td>\n",
|
741 |
+
" <td>1.143937</td>\n",
|
742 |
+
" <td>0.087863</td>\n",
|
743 |
+
" <td>51.898420</td>\n",
|
744 |
+
" </tr>\n",
|
745 |
+
" <tr>\n",
|
746 |
+
" <td>4</td>\n",
|
747 |
+
" <td>0.289600</td>\n",
|
748 |
+
" <td>2.353773</td>\n",
|
749 |
+
" <td>2.353773</td>\n",
|
750 |
+
" <td>1.534201</td>\n",
|
751 |
+
" <td>1.079125</td>\n",
|
752 |
+
" <td>0.355578</td>\n",
|
753 |
+
" <td>55.779856</td>\n",
|
754 |
+
" </tr>\n",
|
755 |
+
" <tr>\n",
|
756 |
+
" <td>5</td>\n",
|
757 |
+
" <td>0.001400</td>\n",
|
758 |
+
" <td>2.629261</td>\n",
|
759 |
+
" <td>2.629261</td>\n",
|
760 |
+
" <td>1.621500</td>\n",
|
761 |
+
" <td>1.166006</td>\n",
|
762 |
+
" <td>0.280154</td>\n",
|
763 |
+
" <td>57.977718</td>\n",
|
764 |
+
" </tr>\n",
|
765 |
+
" </tbody>\n",
|
766 |
+
"</table><p>"
|
767 |
+
],
|
768 |
+
"text/plain": [
|
769 |
+
"<IPython.core.display.HTML object>"
|
770 |
+
]
|
771 |
+
},
|
772 |
+
"metadata": {},
|
773 |
+
"output_type": "display_data"
|
774 |
+
},
|
775 |
+
{
|
776 |
+
"name": "stderr",
|
777 |
+
"output_type": "stream",
|
778 |
+
"text": [
|
779 |
+
"[I 2025-01-09 12:27:05,020] Trial 2 finished with value: 1.6214995462309338 and parameters: {'learning_rate': 2.479942619764035e-05}. Best is trial 2 with value: 1.6214995462309338.\n",
|
780 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
781 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
782 |
+
]
|
783 |
+
},
|
784 |
+
{
|
785 |
+
"data": {
|
786 |
+
"text/html": [
|
787 |
+
"\n",
|
788 |
+
" <div>\n",
|
789 |
+
" \n",
|
790 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
791 |
+
" [620/620 03:25, Epoch 5/5]\n",
|
792 |
+
" </div>\n",
|
793 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
794 |
+
" <thead>\n",
|
795 |
+
" <tr style=\"text-align: left;\">\n",
|
796 |
+
" <th>Epoch</th>\n",
|
797 |
+
" <th>Training Loss</th>\n",
|
798 |
+
" <th>Validation Loss</th>\n",
|
799 |
+
" <th>Mse</th>\n",
|
800 |
+
" <th>Rmse</th>\n",
|
801 |
+
" <th>Mae</th>\n",
|
802 |
+
" <th>R2</th>\n",
|
803 |
+
" <th>Smape</th>\n",
|
804 |
+
" </tr>\n",
|
805 |
+
" </thead>\n",
|
806 |
+
" <tbody>\n",
|
807 |
+
" <tr>\n",
|
808 |
+
" <td>1</td>\n",
|
809 |
+
" <td>0.008000</td>\n",
|
810 |
+
" <td>3.590378</td>\n",
|
811 |
+
" <td>3.590379</td>\n",
|
812 |
+
" <td>1.894829</td>\n",
|
813 |
+
" <td>1.149898</td>\n",
|
814 |
+
" <td>0.017017</td>\n",
|
815 |
+
" <td>46.445611</td>\n",
|
816 |
+
" </tr>\n",
|
817 |
+
" <tr>\n",
|
818 |
+
" <td>2</td>\n",
|
819 |
+
" <td>2.704000</td>\n",
|
820 |
+
" <td>3.476464</td>\n",
|
821 |
+
" <td>3.476464</td>\n",
|
822 |
+
" <td>1.864528</td>\n",
|
823 |
+
" <td>1.125000</td>\n",
|
824 |
+
" <td>0.048205</td>\n",
|
825 |
+
" <td>47.319812</td>\n",
|
826 |
+
" </tr>\n",
|
827 |
+
" <tr>\n",
|
828 |
+
" <td>3</td>\n",
|
829 |
+
" <td>32.099300</td>\n",
|
830 |
+
" <td>3.543669</td>\n",
|
831 |
+
" <td>3.543668</td>\n",
|
832 |
+
" <td>1.882463</td>\n",
|
833 |
+
" <td>1.123369</td>\n",
|
834 |
+
" <td>0.029805</td>\n",
|
835 |
+
" <td>47.717217</td>\n",
|
836 |
+
" </tr>\n",
|
837 |
+
" <tr>\n",
|
838 |
+
" <td>4</td>\n",
|
839 |
+
" <td>0.058200</td>\n",
|
840 |
+
" <td>3.590872</td>\n",
|
841 |
+
" <td>3.590872</td>\n",
|
842 |
+
" <td>1.894960</td>\n",
|
843 |
+
" <td>1.142273</td>\n",
|
844 |
+
" <td>0.016882</td>\n",
|
845 |
+
" <td>48.410091</td>\n",
|
846 |
+
" </tr>\n",
|
847 |
+
" <tr>\n",
|
848 |
+
" <td>5</td>\n",
|
849 |
+
" <td>0.084600</td>\n",
|
850 |
+
" <td>3.600572</td>\n",
|
851 |
+
" <td>3.600573</td>\n",
|
852 |
+
" <td>1.897517</td>\n",
|
853 |
+
" <td>1.145824</td>\n",
|
854 |
+
" <td>0.014226</td>\n",
|
855 |
+
" <td>48.548377</td>\n",
|
856 |
+
" </tr>\n",
|
857 |
+
" </tbody>\n",
|
858 |
+
"</table><p>"
|
859 |
+
],
|
860 |
+
"text/plain": [
|
861 |
+
"<IPython.core.display.HTML object>"
|
862 |
+
]
|
863 |
+
},
|
864 |
+
"metadata": {},
|
865 |
+
"output_type": "display_data"
|
866 |
+
},
|
867 |
+
{
|
868 |
+
"name": "stderr",
|
869 |
+
"output_type": "stream",
|
870 |
+
"text": [
|
871 |
+
"[I 2025-01-09 12:30:33,965] Trial 3 finished with value: 1.8975174797770824 and parameters: {'learning_rate': 1.1750268648920993e-06}. Best is trial 2 with value: 1.6214995462309338.\n",
|
872 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
873 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
874 |
+
]
|
875 |
+
},
|
876 |
+
{
|
877 |
+
"data": {
|
878 |
+
"text/html": [
|
879 |
+
"\n",
|
880 |
+
" <div>\n",
|
881 |
+
" \n",
|
882 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
883 |
+
" [620/620 03:27, Epoch 5/5]\n",
|
884 |
+
" </div>\n",
|
885 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
886 |
+
" <thead>\n",
|
887 |
+
" <tr style=\"text-align: left;\">\n",
|
888 |
+
" <th>Epoch</th>\n",
|
889 |
+
" <th>Training Loss</th>\n",
|
890 |
+
" <th>Validation Loss</th>\n",
|
891 |
+
" <th>Mse</th>\n",
|
892 |
+
" <th>Rmse</th>\n",
|
893 |
+
" <th>Mae</th>\n",
|
894 |
+
" <th>R2</th>\n",
|
895 |
+
" <th>Smape</th>\n",
|
896 |
+
" </tr>\n",
|
897 |
+
" </thead>\n",
|
898 |
+
" <tbody>\n",
|
899 |
+
" <tr>\n",
|
900 |
+
" <td>1</td>\n",
|
901 |
+
" <td>0.085600</td>\n",
|
902 |
+
" <td>3.761341</td>\n",
|
903 |
+
" <td>3.761341</td>\n",
|
904 |
+
" <td>1.939418</td>\n",
|
905 |
+
" <td>1.156432</td>\n",
|
906 |
+
" <td>-0.029790</td>\n",
|
907 |
+
" <td>46.601269</td>\n",
|
908 |
+
" </tr>\n",
|
909 |
+
" <tr>\n",
|
910 |
+
" <td>2</td>\n",
|
911 |
+
" <td>2.913400</td>\n",
|
912 |
+
" <td>3.756832</td>\n",
|
913 |
+
" <td>3.756831</td>\n",
|
914 |
+
" <td>1.938255</td>\n",
|
915 |
+
" <td>1.238454</td>\n",
|
916 |
+
" <td>-0.028555</td>\n",
|
917 |
+
" <td>49.874967</td>\n",
|
918 |
+
" </tr>\n",
|
919 |
+
" <tr>\n",
|
920 |
+
" <td>3</td>\n",
|
921 |
+
" <td>32.276600</td>\n",
|
922 |
+
" <td>3.654472</td>\n",
|
923 |
+
" <td>3.654473</td>\n",
|
924 |
+
" <td>1.911668</td>\n",
|
925 |
+
" <td>1.135091</td>\n",
|
926 |
+
" <td>-0.000531</td>\n",
|
927 |
+
" <td>48.732340</td>\n",
|
928 |
+
" </tr>\n",
|
929 |
+
" <tr>\n",
|
930 |
+
" <td>4</td>\n",
|
931 |
+
" <td>0.083000</td>\n",
|
932 |
+
" <td>3.665871</td>\n",
|
933 |
+
" <td>3.665871</td>\n",
|
934 |
+
" <td>1.914646</td>\n",
|
935 |
+
" <td>1.162767</td>\n",
|
936 |
+
" <td>-0.003652</td>\n",
|
937 |
+
" <td>49.439710</td>\n",
|
938 |
+
" </tr>\n",
|
939 |
+
" <tr>\n",
|
940 |
+
" <td>5</td>\n",
|
941 |
+
" <td>0.055800</td>\n",
|
942 |
+
" <td>3.610057</td>\n",
|
943 |
+
" <td>3.610057</td>\n",
|
944 |
+
" <td>1.900015</td>\n",
|
945 |
+
" <td>1.183222</td>\n",
|
946 |
+
" <td>0.011629</td>\n",
|
947 |
+
" <td>49.474382</td>\n",
|
948 |
+
" </tr>\n",
|
949 |
+
" </tbody>\n",
|
950 |
+
"</table><p>"
|
951 |
+
],
|
952 |
+
"text/plain": [
|
953 |
+
"<IPython.core.display.HTML object>"
|
954 |
+
]
|
955 |
+
},
|
956 |
+
"metadata": {},
|
957 |
+
"output_type": "display_data"
|
958 |
+
},
|
959 |
+
{
|
960 |
+
"name": "stderr",
|
961 |
+
"output_type": "stream",
|
962 |
+
"text": [
|
963 |
+
"[I 2025-01-09 12:34:05,271] Trial 4 finished with value: 1.9000149676084739 and parameters: {'learning_rate': 2.308984942228097e-06}. Best is trial 2 with value: 1.6214995462309338.\n"
|
964 |
+
]
|
965 |
+
}
|
966 |
+
],
|
967 |
+
"source": [
|
968 |
+
"best_trial = hp_trainer.hyperparameter_search(\n",
|
969 |
+
" direction=\"minimize\",\n",
|
970 |
+
" backend=\"optuna\",\n",
|
971 |
+
" hp_space=optuna_hp_space,\n",
|
972 |
+
" n_trials=5,\n",
|
973 |
+
" compute_objective=lambda x: x['eval_rmse'],\n",
|
974 |
+
")"
|
975 |
+
]
|
976 |
+
},
|
977 |
+
{
|
978 |
+
"cell_type": "code",
|
979 |
+
"execution_count": 26,
|
980 |
+
"id": "c9c39fa6-3f84-4082-879a-7efd5d21e174",
|
981 |
+
"metadata": {},
|
982 |
+
"outputs": [
|
983 |
+
{
|
984 |
+
"data": {
|
985 |
+
"text/plain": [
|
986 |
+
"BestRun(run_id='2', objective=1.6214995462309338, hyperparameters={'learning_rate': 2.479942619764035e-05}, run_summary=None)"
|
987 |
+
]
|
988 |
+
},
|
989 |
+
"execution_count": 26,
|
990 |
+
"metadata": {},
|
991 |
+
"output_type": "execute_result"
|
992 |
+
}
|
993 |
+
],
|
994 |
+
"source": [
|
995 |
+
"best_trial"
|
996 |
+
]
|
997 |
+
},
|
998 |
+
{
|
999 |
+
"cell_type": "markdown",
|
1000 |
+
"id": "f511f354-5c3e-4c62-8063-f769a6c1b9ca",
|
1001 |
+
"metadata": {},
|
1002 |
+
"source": [
|
1003 |
+
"### Fit and upload the best Model\n",
|
1004 |
+
"We re-fit the model with the best hyperparameters in accordaince with this [forum post](https://discuss.huggingface.co/t/how-to-save-the-best-trials-model-using-trainer-hyperparameter-search/8783/4)"
|
1005 |
+
]
|
1006 |
+
},
|
1007 |
+
{
|
1008 |
+
"cell_type": "code",
|
1009 |
+
"execution_count": 27,
|
1010 |
+
"id": "ad4e4cf2-286e-4b4a-9c3b-2024be1769b8",
|
1011 |
+
"metadata": {},
|
1012 |
+
"outputs": [
|
1013 |
+
{
|
1014 |
+
"name": "stderr",
|
1015 |
+
"output_type": "stream",
|
1016 |
+
"text": [
|
1017 |
+
"Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
1018 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
1019 |
+
"/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
|
1020 |
+
" warnings.warn(\n"
|
1021 |
+
]
|
1022 |
+
}
|
1023 |
+
],
|
1024 |
+
"source": [
|
1025 |
+
"best_trainer = Trainer(\n",
|
1026 |
+
" model=model_init(None),\n",
|
1027 |
+
" args=gen_training_args({**best_trial.hyperparameters}),\n",
|
1028 |
+
" train_dataset=tokenized_dataset[\"train\"],\n",
|
1029 |
+
" eval_dataset=tokenized_dataset[\"test\"],\n",
|
1030 |
+
" compute_metrics=compute_metrics_for_regression,\n",
|
1031 |
+
")"
|
1032 |
+
]
|
1033 |
+
},
|
1034 |
+
{
|
1035 |
+
"cell_type": "code",
|
1036 |
+
"execution_count": 28,
|
1037 |
+
"id": "a031f566-4be1-440a-8481-f23e609ac3b3",
|
1038 |
+
"metadata": {},
|
1039 |
+
"outputs": [
|
1040 |
+
{
|
1041 |
+
"data": {
|
1042 |
+
"text/html": [
|
1043 |
+
"\n",
|
1044 |
+
" <div>\n",
|
1045 |
+
" \n",
|
1046 |
+
" <progress value='620' max='620' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1047 |
+
" [620/620 03:25, Epoch 5/5]\n",
|
1048 |
+
" </div>\n",
|
1049 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1050 |
+
" <thead>\n",
|
1051 |
+
" <tr style=\"text-align: left;\">\n",
|
1052 |
+
" <th>Epoch</th>\n",
|
1053 |
+
" <th>Training Loss</th>\n",
|
1054 |
+
" <th>Validation Loss</th>\n",
|
1055 |
+
" <th>Mse</th>\n",
|
1056 |
+
" <th>Rmse</th>\n",
|
1057 |
+
" <th>Mae</th>\n",
|
1058 |
+
" <th>R2</th>\n",
|
1059 |
+
" <th>Smape</th>\n",
|
1060 |
+
" </tr>\n",
|
1061 |
+
" </thead>\n",
|
1062 |
+
" <tbody>\n",
|
1063 |
+
" <tr>\n",
|
1064 |
+
" <td>1</td>\n",
|
1065 |
+
" <td>0.115200</td>\n",
|
1066 |
+
" <td>4.084211</td>\n",
|
1067 |
+
" <td>4.084211</td>\n",
|
1068 |
+
" <td>2.020943</td>\n",
|
1069 |
+
" <td>1.219903</td>\n",
|
1070 |
+
" <td>-0.118186</td>\n",
|
1071 |
+
" <td>49.023473</td>\n",
|
1072 |
+
" </tr>\n",
|
1073 |
+
" <tr>\n",
|
1074 |
+
" <td>2</td>\n",
|
1075 |
+
" <td>1.239000</td>\n",
|
1076 |
+
" <td>3.803578</td>\n",
|
1077 |
+
" <td>3.803578</td>\n",
|
1078 |
+
" <td>1.950276</td>\n",
|
1079 |
+
" <td>1.289222</td>\n",
|
1080 |
+
" <td>-0.041354</td>\n",
|
1081 |
+
" <td>52.775413</td>\n",
|
1082 |
+
" </tr>\n",
|
1083 |
+
" <tr>\n",
|
1084 |
+
" <td>3</td>\n",
|
1085 |
+
" <td>27.825600</td>\n",
|
1086 |
+
" <td>3.245966</td>\n",
|
1087 |
+
" <td>3.245967</td>\n",
|
1088 |
+
" <td>1.801657</td>\n",
|
1089 |
+
" <td>1.102216</td>\n",
|
1090 |
+
" <td>0.111311</td>\n",
|
1091 |
+
" <td>51.747030</td>\n",
|
1092 |
+
" </tr>\n",
|
1093 |
+
" <tr>\n",
|
1094 |
+
" <td>4</td>\n",
|
1095 |
+
" <td>0.000100</td>\n",
|
1096 |
+
" <td>2.413429</td>\n",
|
1097 |
+
" <td>2.413429</td>\n",
|
1098 |
+
" <td>1.553521</td>\n",
|
1099 |
+
" <td>1.081085</td>\n",
|
1100 |
+
" <td>0.339245</td>\n",
|
1101 |
+
" <td>52.221513</td>\n",
|
1102 |
+
" </tr>\n",
|
1103 |
+
" <tr>\n",
|
1104 |
+
" <td>5</td>\n",
|
1105 |
+
" <td>0.166600</td>\n",
|
1106 |
+
" <td>2.462405</td>\n",
|
1107 |
+
" <td>2.462406</td>\n",
|
1108 |
+
" <td>1.569205</td>\n",
|
1109 |
+
" <td>1.182182</td>\n",
|
1110 |
+
" <td>0.325836</td>\n",
|
1111 |
+
" <td>56.614470</td>\n",
|
1112 |
+
" </tr>\n",
|
1113 |
+
" </tbody>\n",
|
1114 |
+
"</table><p>"
|
1115 |
+
],
|
1116 |
+
"text/plain": [
|
1117 |
+
"<IPython.core.display.HTML object>"
|
1118 |
+
]
|
1119 |
+
},
|
1120 |
+
"metadata": {},
|
1121 |
+
"output_type": "display_data"
|
1122 |
+
},
|
1123 |
+
{
|
1124 |
+
"data": {
|
1125 |
+
"text/plain": [
|
1126 |
+
"TrainOutput(global_step=620, training_loss=4.329616037725622, metrics={'train_runtime': 205.4329, 'train_samples_per_second': 11.999, 'train_steps_per_second': 3.018, 'total_flos': 3359849068769280.0, 'train_loss': 4.329616037725622, 'epoch': 5.0})"
|
1127 |
+
]
|
1128 |
+
},
|
1129 |
+
"execution_count": 28,
|
1130 |
+
"metadata": {},
|
1131 |
+
"output_type": "execute_result"
|
1132 |
+
}
|
1133 |
+
],
|
1134 |
+
"source": [
|
1135 |
+
"best_trainer.train() "
|
1136 |
+
]
|
1137 |
+
},
|
1138 |
+
{
|
1139 |
+
"cell_type": "code",
|
1140 |
+
"execution_count": 29,
|
1141 |
+
"id": "09a0f8e2-7986-4171-902e-c08bcd5d6088",
|
1142 |
+
"metadata": {},
|
1143 |
+
"outputs": [
|
1144 |
+
{
|
1145 |
+
"data": {
|
1146 |
+
"text/html": [
|
1147 |
+
"\n",
|
1148 |
+
" <div>\n",
|
1149 |
+
" \n",
|
1150 |
+
" <progress value='14' max='14' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1151 |
+
" [14/14 00:01]\n",
|
1152 |
+
" </div>\n",
|
1153 |
+
" "
|
1154 |
+
],
|
1155 |
+
"text/plain": [
|
1156 |
+
"<IPython.core.display.HTML object>"
|
1157 |
+
]
|
1158 |
+
},
|
1159 |
+
"metadata": {},
|
1160 |
+
"output_type": "display_data"
|
1161 |
+
},
|
1162 |
+
{
|
1163 |
+
"data": {
|
1164 |
+
"text/plain": [
|
1165 |
+
"{'eval_loss': 2.4624054431915283,\n",
|
1166 |
+
" 'eval_mse': 2.4624056816101074,\n",
|
1167 |
+
" 'eval_rmse': 1.5692054300218654,\n",
|
1168 |
+
" 'eval_mae': 1.182181715965271,\n",
|
1169 |
+
" 'eval_r2': 0.325836181640625,\n",
|
1170 |
+
" 'eval_smape': 56.61447048187256,\n",
|
1171 |
+
" 'eval_runtime': 1.3489,\n",
|
1172 |
+
" 'eval_samples_per_second': 40.774,\n",
|
1173 |
+
" 'eval_steps_per_second': 10.379,\n",
|
1174 |
+
" 'epoch': 5.0}"
|
1175 |
+
]
|
1176 |
+
},
|
1177 |
+
"execution_count": 29,
|
1178 |
+
"metadata": {},
|
1179 |
+
"output_type": "execute_result"
|
1180 |
+
}
|
1181 |
+
],
|
1182 |
+
"source": [
|
1183 |
+
"best_trainer.evaluate()"
|
1184 |
+
]
|
1185 |
+
},
|
1186 |
+
{
|
1187 |
+
"cell_type": "code",
|
1188 |
+
"execution_count": 30,
|
1189 |
+
"id": "6360b8dd-c456-4c0f-a79f-b1fb7a99ad19",
|
1190 |
+
"metadata": {},
|
1191 |
+
"outputs": [
|
1192 |
+
{
|
1193 |
+
"data": {
|
1194 |
+
"application/vnd.jupyter.widget-view+json": {
|
1195 |
+
"model_id": "49576b7f5f6b4ea781dd6198df4f33f7",
|
1196 |
+
"version_major": 2,
|
1197 |
+
"version_minor": 0
|
1198 |
+
},
|
1199 |
+
"text/plain": [
|
1200 |
+
"events.out.tfevents.1736455080.bazzite: 0%| | 0.00/40.0 [00:00<?, ?B/s]"
|
1201 |
+
]
|
1202 |
+
},
|
1203 |
+
"metadata": {},
|
1204 |
+
"output_type": "display_data"
|
1205 |
+
},
|
1206 |
+
{
|
1207 |
+
"data": {
|
1208 |
+
"text/plain": [
|
1209 |
+
"CommitInfo(commit_url='https://huggingface.co/Forecast-ing/modernBERT-content-regression/commit/16f1dc87782b2735f8fef84a5b10807b6cbe5565', commit_message='End of training', commit_description='', oid='16f1dc87782b2735f8fef84a5b10807b6cbe5565', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Forecast-ing/modernBERT-content-regression', endpoint='https://huggingface.co', repo_type='model', repo_id='Forecast-ing/modernBERT-content-regression'), pr_revision=None, pr_num=None)"
|
1210 |
+
]
|
1211 |
+
},
|
1212 |
+
"execution_count": 30,
|
1213 |
+
"metadata": {},
|
1214 |
+
"output_type": "execute_result"
|
1215 |
+
}
|
1216 |
+
],
|
1217 |
+
"source": [
|
1218 |
+
"tokenizer.save_pretrained(\"modernBERT-content-regression\")\n",
|
1219 |
+
"best_trainer.create_model_card()\n",
|
1220 |
+
"best_trainer.push_to_hub()"
|
1221 |
+
]
|
1222 |
+
},
|
1223 |
+
{
|
1224 |
+
"cell_type": "code",
|
1225 |
+
"execution_count": null,
|
1226 |
+
"id": "2c9971d5-5db6-48ac-be23-4131f75a961c",
|
1227 |
+
"metadata": {},
|
1228 |
+
"outputs": [],
|
1229 |
+
"source": []
|
1230 |
+
},
|
1231 |
+
{
|
1232 |
+
"cell_type": "code",
|
1233 |
+
"execution_count": null,
|
1234 |
+
"id": "083db2d4-2601-44ef-aa87-5f4a1d66f8e3",
|
1235 |
+
"metadata": {},
|
1236 |
+
"outputs": [],
|
1237 |
+
"source": []
|
1238 |
+
},
|
1239 |
+
{
|
1240 |
+
"cell_type": "code",
|
1241 |
+
"execution_count": null,
|
1242 |
+
"id": "9db9a38d-653d-4b71-8cc2-e0f65ae065bd",
|
1243 |
+
"metadata": {},
|
1244 |
+
"outputs": [],
|
1245 |
+
"source": []
|
1246 |
+
}
|
1247 |
+
],
|
1248 |
+
"metadata": {
|
1249 |
+
"kernelspec": {
|
1250 |
+
"display_name": "Python 3 (ipykernel)",
|
1251 |
+
"language": "python",
|
1252 |
+
"name": "python3"
|
1253 |
+
},
|
1254 |
+
"language_info": {
|
1255 |
+
"codemirror_mode": {
|
1256 |
+
"name": "ipython",
|
1257 |
+
"version": 3
|
1258 |
+
},
|
1259 |
+
"file_extension": ".py",
|
1260 |
+
"mimetype": "text/x-python",
|
1261 |
+
"name": "python",
|
1262 |
+
"nbconvert_exporter": "python",
|
1263 |
+
"pygments_lexer": "ipython3",
|
1264 |
+
"version": "3.12.8"
|
1265 |
+
}
|
1266 |
+
},
|
1267 |
+
"nbformat": 4,
|
1268 |
+
"nbformat_minor": 5
|
1269 |
+
}
|
pyproject.toml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "modernbert-content-regression"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Add your description here"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.12"
|
7 |
+
dependencies = [
|
8 |
+
"accelerate>=1.2.1",
|
9 |
+
"catboost>=1.2.7",
|
10 |
+
"datasets>=3.2.0",
|
11 |
+
"evaluate>=0.4.3",
|
12 |
+
"hf-transfer>=0.1.8",
|
13 |
+
"optuna>=4.1.0",
|
14 |
+
"python-dotenv>=1.0.1",
|
15 |
+
"scikit-learn>=1.6.0",
|
16 |
+
"tensorboardx>=2.6.2.2",
|
17 |
+
"transformers",
|
18 |
+
]
|
19 |
+
|
20 |
+
[tool.uv.sources]
|
21 |
+
tinytroupe = { git = "https://github.com/microsoft/TinyTroupe.git" }
|
22 |
+
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "6e0515e99c39444caae39472ee1b2fd76ece32f1" }
|
23 |
+
|
24 |
+
[dependency-groups]
|
25 |
+
dev = [
|
26 |
+
"ruff>=0.8.4",
|
27 |
+
]
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|