{ "cells": [
{ "cell_type": "code", "execution_count": 1, "id": "51a9bb64-969a-4fc0-aa76-4bd42b08c21a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from catboost import CatBoostRegressor, Pool\n", "from datasets import load_dataset\n", "from dotenv import load_dotenv\n", "from huggingface_hub import HfFolder, login\n", "from sklearn.metrics import (\n", "    mean_absolute_error,\n", "    mean_squared_error,\n", "    r2_score,\n", "    root_mean_squared_error,\n", ")\n", "from transformers import (\n", "    AutoModelForSequenceClassification,\n", "    AutoTokenizer,\n", "    Trainer,\n", "    TrainingArguments,\n", ")\n", "\n", "load_dotenv()" ] },
{ "cell_type": "code", "execution_count": 2, "id": "9b53c782-5dbb-4dd6-b541-8e4fab3f3ddf", "metadata": {}, "outputs": [], "source": [ "# Index os.environ directly so a missing key fails fast with a KeyError.\n", "# os.getenv would return None, and login(token=None) falls back to an interactive\n", "# prompt, which stalls \"Restart & Run All\".\n", "login(token=os.environ[\"HUGGINGFACE_API_KEY\"])" ] },
{ "cell_type": "markdown", "id": "87f58ec1-231d-4c17-8af2-c366af55e375", "metadata": {}, "source": [ "### Dataset prep" ] },
{ "cell_type": "code", "execution_count": 3, "id": "0ea777d8-988b-421a-8d76-b0f9256ab61b", "metadata": {}, "outputs": [], "source": [ "raw_dataset = load_dataset(\"Forecast-ing/email-clickthrough\")" ] },
{ "cell_type": "code", "execution_count": 4, "id": "0c8441e7-1606-4c61-8f6c-34ebe1e107c0", "metadata": {}, "outputs": [], "source": [ "raw_dataset = raw_dataset.rename_column(\"label\", \"labels\")" ] },
{ "cell_type": "code", "execution_count": 5, "id": "201b806d-6c94-4053-98b6-4d22dcdda08a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3292" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_dataset[\"train\"].to_pandas()[\"text\"].str.len().max()" ] },
{ "cell_type": "code", "execution_count": 6, "id": "cbd4d6ec-b293-49f1-925f-239549dab61e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.2427007299270073" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Share of emails longer than 2048 *characters*. The tokenizer limit set below is\n", "# 2048 *tokens*, so this likely overstates how many examples get truncated.\n", "(raw_dataset[\"train\"].to_pandas()[\"text\"].str.len() > 2048).mean()" ] },
{ "cell_type": "code", "execution_count": 7, "id": "1101c038-3f83-4055-b938-3861ac43cf8f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "count 548.000000\n", "mean 2.879635\n", "std 2.423870\n", "min 0.450000\n", "25% 1.510000\n", "50% 2.025000\n", "75% 3.267500\n", "max 25.370000\n", "Name: labels, dtype: float64" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_dataset[\"train\"].to_pandas()[\"labels\"].describe()" ] },
{ "cell_type": "code", "execution_count": 8, "id": "7546577f-4d7f-41b5-a68f-095fc0e8eec4", "metadata": {}, "outputs": [], "source": [ "# NOTE: this rebinds raw_dataset. Re-running this cell on its own splits the\n", "# already-split train set again, so restart from the load cell instead.\n", "raw_dataset = raw_dataset[\"train\"].train_test_split(test_size=0.1, seed=1)" ] },
{ "cell_type": "code", "execution_count": 9, "id": "a0c28ab6-20f8-47dc-8ee9-179ea15830e0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train dataset size: 493\n", "Test dataset size: 55\n" ] } ], "source": [ "print(f\"Train dataset size: {len(raw_dataset['train'])}\")\n", "print(f\"Test dataset size: {len(raw_dataset['test'])}\")" ] },
{ "cell_type": "markdown", "id": "a2c3e7c3-e31d-4d8f-8d7d-e62050a9ae9d", "metadata": {}, "source": [ "### CatBoost Benchmark" ] },
{ "cell_type": "code", "execution_count": 10, "id": "01aaea26-e1df-4493-b9cd-732c3b7a76a9", "metadata": {}, "outputs": [], "source": [ "catboost_train = raw_dataset[\"train\"].to_pandas()\n", "catboost_test = raw_dataset[\"test\"].to_pandas()" ] },
{ "cell_type": "code", "execution_count": 11, "id": "0243e07d-69ba-41e5-b54d-d0f4988bbf9f", "metadata": {}, "outputs": [], "source": [ "text_columns = [\"text\"]\n", "label = \"labels\"" ] },
{ "cell_type": "code", "execution_count": 12, "id": "ba17040c-5882-47dc-a8af-a7557356840f", "metadata": {}, "outputs": [], "source": [ "train_pool = Pool(\n", "    data=catboost_train[text_columns],\n", "    label=catboost_train[label],\n", "    text_features=text_columns,\n", ")\n", "test_pool = Pool(\n", "    data=catboost_test[text_columns],\n", "    label=catboost_test[label],\n", "    text_features=text_columns,\n", ")" ] },
{ "cell_type": "code", "execution_count": 13, "id": "d8b3768e-6f30-41ce-a209-bd915a997d8a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.045569\n", "0:\tlearn: 2.4332854\ttest: 1.8670741\tbest: 1.8670741 (0)\ttotal: 60.5ms\tremaining: 1m\n", "100:\tlearn: 1.4972558\ttest: 1.6247590\tbest: 1.6048404 (59)\ttotal: 2.5s\tremaining: 22.2s\n", "200:\tlearn: 1.1104040\ttest: 1.6015944\tbest: 1.5975296 (197)\ttotal: 4.91s\tremaining: 19.5s\n", "300:\tlearn: 0.8568033\ttest: 1.6102309\tbest: 1.5975296 (197)\ttotal: 7.33s\tremaining: 17s\n", "400:\tlearn: 0.7096792\ttest: 1.6090190\tbest: 1.5975296 (197)\ttotal: 9.72s\tremaining: 14.5s\n", "500:\tlearn: 0.6056532\ttest: 1.6083240\tbest: 1.5975296 (197)\ttotal: 12.1s\tremaining: 12s\n", "600:\tlearn: 0.5298016\ttest: 1.6175366\tbest: 1.5975296 (197)\ttotal: 14.5s\tremaining: 9.64s\n", "700:\tlearn: 0.4701467\ttest: 1.6262668\tbest: 1.5975296 (197)\ttotal: 16.9s\tremaining: 7.23s\n", "800:\tlearn: 0.4233732\ttest: 1.6199203\tbest: 1.5975296 (197)\ttotal: 19.4s\tremaining: 4.81s\n", "900:\tlearn: 0.3837074\ttest: 1.6104091\tbest: 1.5975296 (197)\ttotal: 21.8s\tremaining: 2.39s\n", "999:\tlearn: 0.3501113\ttest: 1.6131207\tbest: 1.5975296 (197)\ttotal: 24.2s\tremaining: 0us\n", "\n", "bestTest = 1.597529566\n", "bestIteration = 197\n", "\n", "Shrink model to first 198 iterations.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test_pool is only monitored here. With use_best_model=False the model is not\n", "# shrunk to the iteration that scored best on the test set, so the test metrics\n", "# reported below are not selected on the data they are measured on.\n", "model = CatBoostRegressor(loss_function=\"RMSE\", verbose=100)\n", "\n", "model.fit(train_pool, eval_set=test_pool, use_best_model=False)" ] },
{ "cell_type": "code", "execution_count": 14, "id": 
"837b22a8-241d-49b3-a1ae-915893121319", "metadata": {}, "outputs": [], "source": [ "y_pred = model.predict(test_pool)\n", "y_val = catboost_test[label]" ] },
{ "cell_type": "code", "execution_count": 15, "id": "478521bf-85be-49a1-8461-020587f146d2", "metadata": {}, "outputs": [], "source": [
"def smape(y_true, y_pred):\n",
"    \"\"\"Symmetric mean absolute percentage error, in percent (range 0-200).\n",
"\n",
"    Both inputs are flattened to 1-D first, so an (n, 1) prediction array is compared\n",
"    element-wise with (n,) labels instead of broadcasting to an (n, n) matrix.\n",
"    Pairs where both values are exactly 0 count as a perfect match (0) instead of NaN.\n",
"    \"\"\"\n",
"    y_true = np.asarray(y_true, dtype=float).ravel()\n",
"    y_pred = np.asarray(y_pred, dtype=float).ravel()\n",
"    denominator = np.abs(y_true) + np.abs(y_pred)\n",
"    ratio = np.divide(\n",
"        2 * np.abs(y_pred - y_true),\n",
"        denominator,\n",
"        out=np.zeros_like(denominator),\n",
"        where=denominator != 0,\n",
"    )\n",
"    return 100 * np.mean(ratio)\n",
"\n",
"\n",
"def calculate_metrics(y_val, y_pred):\n",
"    \"\"\"Return MSE, RMSE, MAE, R^2 and SMAPE of `y_pred` against `y_val` as a dict.\"\"\"\n",
"    mse = mean_squared_error(y_val, y_pred)\n",
"    rmse = np.sqrt(mse)\n",
"    mae = mean_absolute_error(y_val, y_pred)\n",
"    r2 = r2_score(y_val, y_pred)\n",
"    smape_value = smape(y_val, y_pred)\n",
"    return {\n",
"        \"mse\": mse,\n",
"        \"rmse\": rmse,\n",
"        \"mae\": mae,\n",
"        \"r2\": r2,\n",
"        \"smape\": smape_value,\n",
"    }" ] },
{ "cell_type": "code", "execution_count": 16, "id": "91ea95e2-1818-45a6-8725-3a1353cb5b97", "metadata": {}, "outputs": [], "source": [ "catboost_metrics = calculate_metrics(y_val, y_pred)" ] },
{ "cell_type": "code", "execution_count": 17, "id": "e28e359e-f69c-4ee8-9bbd-e7afafafcd26", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mse': 2.552100633998035,\n", " 'rmse': 1.5975295408843102,\n", " 'mae': 1.1439370629666958,\n", " 'r2': 0.30127932054387174,\n", " 'smape': 37.63064694052479}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "catboost_metrics" ] },
{ "cell_type": "markdown", "id": "7afac97a-1e69-47e4-9ecd-ece7ebe7b48f", "metadata": {}, "source": [ "### Fine-tuning ModernBERT" ] },
{ "cell_type": "code", "execution_count": 18, "id": "031df047-2c18-4ec9-a498-596a7cf965b7", "metadata": {}, "outputs": [], "source": [
"model_id = \"answerdotai/ModernBERT-base\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"# padding=\"max_length\" + truncation=True below pad/truncate every example to this length.\n",
"tokenizer.model_max_length = 2048\n",
"\n",
"\n",
"def tokenize(batch):\n",
"    \"\"\"Tokenize a batch of `text` values to fixed-length (model_max_length) inputs.\"\"\"\n",
"    return tokenizer(\n",
"        batch[\"text\"], padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
"    )" ] },
{ "cell_type": "code", "execution_count": 19, "id": "bbb711e8-8da6-401b-a2bd-c1637731f6c9", "metadata": {}, "outputs": [], "source": [ "tokenized_dataset = raw_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])" ] },
{ "cell_type": "code", "execution_count": 20, "id": "ca4c2942-bf9e-4579-82b9-654fbee85b54", "metadata": {}, "outputs": [], "source": [
"def model_init(trial):\n",
"    \"\"\"Load `model_id` with a newly initialized single-output regression head.\n",
"\n",
"    `trial` is unused; the parameter exists so Trainer's hyperparameter search can call\n",
"    model_init(trial) to get a fresh model per trial.\n",
"    \"\"\"\n",
"    model = AutoModelForSequenceClassification.from_pretrained(\n",
"        model_id, num_labels=1, ignore_mismatched_sizes=True, problem_type=\"regression\"\n",
"    )\n",
"    return model" ] },
{ "cell_type": "code", "execution_count": 21, "id": "a2005499-e139-4151-9ebe-2759710149b1", "metadata": {}, "outputs": [], "source": [
"def gen_training_args(additional_args=None):\n",
"    \"\"\"Build TrainingArguments from the shared defaults below.\n",
"\n",
"    Keys in `additional_args` (e.g. the searched learning rate) override the default\n",
"    with the same name. None (the default) means no overrides.\n",
"    \"\"\"\n",
"    default_args = {\n",
"        \"output_dir\": \"./modernBERT-content-regression\",\n",
"        \"per_device_eval_batch_size\": 4,\n",
"        \"per_device_train_batch_size\": 4,\n",
"        \"num_train_epochs\": 5,\n",
"        \"bf16\": True,  # bfloat16 training\n",
"        \"optim\": \"adamw_torch_fused\",  # improved optimizer\n",
"        \"logging_strategy\": \"steps\",\n",
"        \"logging_steps\": 1,\n",
"        \"eval_strategy\": \"epoch\",  # replaces the deprecated `evaluation_strategy`\n",
"        \"save_strategy\": \"epoch\",\n",
"        \"save_total_limit\": 1,\n",
"        \"metric_for_best_model\": \"rmse\",\n",
"        \"greater_is_better\": False,\n",
"        \"report_to\": \"tensorboard\",\n",
"        \"push_to_hub\": True,\n",
"        \"hub_private_repo\": True,\n",
"        \"hub_strategy\": \"every_save\",\n",
"        \"hub_token\": HfFolder.get_token(),\n",
"    }\n",
"    # Merge into one dict instead of passing two ** mappings, so overriding a default key\n",
"    # does not raise \"got multiple values for keyword argument\".\n",
"    return TrainingArguments(**{**default_args, **(additional_args or {})})" ] },
{ "cell_type": "code", "execution_count": 22, "id": "e7e0ee17-9a1c-4789-8a4b-d2469a88837b", "metadata": {}, "outputs": [], "source": [
"def compute_metrics_for_regression(eval_pred):\n",
"    \"\"\"Trainer `compute_metrics` hook: score the regression outputs against the labels.\"\"\"\n",
"    predictions, labels = eval_pred\n",
"    # With num_labels=1 the predictions arrive as (n, 1) while labels are (n,).\n",
"    # Flatten both to 1-D: keeping predictions as (n, 1) makes smape broadcast against\n",
"    # the labels into an (n, n) matrix and report a meaningless value.\n",
"    predictions = np.asarray(predictions).reshape(-1)\n",
"    labels = np.asarray(labels).reshape(-1)\n",
"    return calculate_metrics(labels, predictions)\n" ] }, {
"cell_type": "code", "execution_count": 23, "id": "10981164-fffc-4128-89ae-41c9238074cd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "/tmp/ipykernel_22314/2727960756.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " hp_trainer = Trainer(\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [
"# NOTE: trials are scored on the test split (eval_dataset below), so the test metrics\n",
"# of the selected model are optimistic. A validation split carved from train would\n",
"# keep the test set untouched.\n",
"hp_trainer = Trainer(\n",
"    model=None,  # a fresh model is built for each trial by model_init\n",
"    args=gen_training_args(),\n",
"    train_dataset=tokenized_dataset[\"train\"],\n",
"    eval_dataset=tokenized_dataset[\"test\"],\n",
"    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument\n",
"    compute_metrics=compute_metrics_for_regression,\n",
"    model_init=model_init,\n",
")" ] },
{ "cell_type": "code", "execution_count": 24, "id": "261a4988-e5f9-469f-b8ae-ef51ae6a95df", "metadata": {}, "outputs": [], "source": [
"def optuna_hp_space(trial):\n",
"    \"\"\"Optuna search space: learning rate only, sampled log-uniformly in [5e-7, 5e-5].\"\"\"\n",
"    return {\n",
"        \"learning_rate\": trial.suggest_float(\"learning_rate\", 5e-7, 5e-5, log=True),\n",
"    }" ] },
{ "cell_type": "code", "execution_count": 25, "id": "7e7fb4f6-4eb4-4a4c-931d-a97c1614a5b9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:16:25,726] A new study created in memory with name: no-name-2f3f9073-d130-4bb1-9447-7262f2b7bd75\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at 
answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [620/620 03:27, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.2380004.5730084.5730082.1384591.324540-0.25201054.242009
23.7685004.0934524.0934522.0232281.458057-0.12071653.770840
327.6610003.3618753.3618741.8335411.1266700.07957752.641284
40.0923002.7594592.7594591.6611621.0400740.24450853.009331
50.0203002.7332502.7332501.6532541.0786530.25168454.187167

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:19:57,000] Trial 0 finished with value: 1.6532543369745685 and parameters: {'learning_rate': 1.9437267223645173e-05}. Best is trial 0 with value: 1.6532543369745685.\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [620/620 03:30, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.0335003.7307573.7307571.9315171.167591-0.02141646.438679
23.0211003.5324183.5324201.8794731.1710510.03288548.273236
332.4544003.6709443.6709441.9159711.159171-0.00504148.529482
40.0743003.6905463.6905461.9210791.179955-0.01040749.107727
50.0988003.6774393.6774391.9176651.188619-0.00681949.251461

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:23:31,566] Trial 1 finished with value: 1.91766510403085 and parameters: {'learning_rate': 1.5810058165067856e-06}. Best is trial 0 with value: 1.6532543369745685.\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [620/620 03:28, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.3115004.0905904.0905902.0225211.229977-0.11993250.514507
22.6528004.8523184.8523192.2027981.465739-0.32848054.715651
324.6264003.3316103.3316101.8252701.1439370.08786351.898420
40.2896002.3537732.3537731.5342011.0791250.35557855.779856
50.0014002.6292612.6292611.6215001.1660060.28015457.977718

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:27:05,020] Trial 2 finished with value: 1.6214995462309338 and parameters: {'learning_rate': 2.479942619764035e-05}. Best is trial 2 with value: 1.6214995462309338.\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [620/620 03:25, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.0080003.5903783.5903791.8948291.1498980.01701746.445611
22.7040003.4764643.4764641.8645281.1250000.04820547.319812
332.0993003.5436693.5436681.8824631.1233690.02980547.717217
40.0582003.5908723.5908721.8949601.1422730.01688248.410091
50.0846003.6005723.6005731.8975171.1458240.01422648.548377

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:30:33,965] Trial 3 finished with value: 1.8975174797770824 and parameters: {'learning_rate': 1.1750268648920993e-06}. Best is trial 2 with value: 1.6214995462309338.\n", "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [620/620 03:27, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.0856003.7613413.7613411.9394181.156432-0.02979046.601269
22.9134003.7568323.7568311.9382551.238454-0.02855549.874967
332.2766003.6544723.6544731.9116681.135091-0.00053148.732340
40.0830003.6658713.6658711.9146461.162767-0.00365249.439710
50.0558003.6100573.6100571.9000151.1832220.01162949.474382

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-01-09 12:34:05,271] Trial 4 finished with value: 1.9000149676084739 and parameters: {'learning_rate': 2.308984942228097e-06}. Best is trial 2 with value: 1.6214995462309338.\n" ] } ], "source": [
"best_trial = hp_trainer.hyperparameter_search(\n",
"    direction=\"minimize\",\n",
"    backend=\"optuna\",\n",
"    hp_space=optuna_hp_space,\n",
"    n_trials=5,\n",
"    # Each trial's objective is the eval RMSE reported after its final epoch.\n",
"    compute_objective=lambda metrics: metrics[\"eval_rmse\"],\n",
")" ] },
{ "cell_type": "code", "execution_count": 26, "id": "c9c39fa6-3f84-4082-879a-7efd5d21e174", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BestRun(run_id='2', objective=1.6214995462309338, hyperparameters={'learning_rate': 2.479942619764035e-05}, run_summary=None)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_trial" ] },
{ "cell_type": "markdown", "id": "f511f354-5c3e-4c62-8063-f769a6c1b9ca", "metadata": {}, "source": [ "### Fit and upload the best model\n", "`hyperparameter_search` returns only the best hyperparameters (see `best_trial` above), not a trained model, so we re-fit a fresh model with them in accordance with this [forum post](https://discuss.huggingface.co/t/how-to-save-the-best-trials-model-using-trainer-hyperparameter-search/8783/4)." ] },
{ "cell_type": "code", "execution_count": 27, "id": "ad4e4cf2-286e-4b4a-9c3b-2024be1769b8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "/var/home/robin/Development/modernbert-content-regression/.venv/lib/python3.12/site-packages/transformers/training_args.py:1573: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n" ] } ], "source": [
"best_trainer = Trainer(\n",
"    model=model_init(None),\n",
"    args=gen_training_args(best_trial.hyperparameters),\n",
"    train_dataset=tokenized_dataset[\"train\"],\n",
"    eval_dataset=tokenized_dataset[\"test\"],\n",
"    # Attach the tokenizer so it is saved and pushed to the Hub alongside the model.\n",
"    processing_class=tokenizer,\n",
"    compute_metrics=compute_metrics_for_regression,\n",
")" ] },
{ "cell_type": "code", "execution_count": 28, "id": "a031f566-4be1-440a-8481-f23e609ac3b3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [620/620 03:25, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossMseRmseMaeR2Smape
10.1152004.0842114.0842112.0209431.219903-0.11818649.023473
21.2390003.8035783.8035781.9502761.289222-0.04135452.775413
327.8256003.2459663.2459671.8016571.1022160.11131151.747030
40.0001002.4134292.4134291.5535211.0810850.33924552.221513
50.1666002.4624052.4624061.5692051.1821820.32583656.614470

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=620, training_loss=4.329616037725622, metrics={'train_runtime': 205.4329, 'train_samples_per_second': 11.999, 'train_steps_per_second': 3.018, 'total_flos': 3359849068769280.0, 'train_loss': 4.329616037725622, 'epoch': 5.0})" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_output = best_trainer.train()\n", "train_output" ] },
{ "cell_type": "code", "execution_count": 29, "id": "09a0f8e2-7986-4171-902e-c08bcd5d6088", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [14/14 00:01]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 2.4624054431915283,\n", " 'eval_mse': 2.4624056816101074,\n", " 'eval_rmse': 1.5692054300218654,\n", " 'eval_mae': 1.182181715965271,\n", " 'eval_r2': 0.325836181640625,\n", " 'eval_smape': 56.61447048187256,\n", " 'eval_runtime': 1.3489,\n", " 'eval_samples_per_second': 40.774,\n", " 'eval_steps_per_second': 10.379,\n", " 'epoch': 5.0}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": 30, "id": "6360b8dd-c456-4c0f-a79f-b1fb7a99ad19", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "49576b7f5f6b4ea781dd6198df4f33f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "events.out.tfevents.1736455080.bazzite: 0%| | 0.00/40.0 [00:00