Varine commited on
Commit
2a50343
·
verified ·
1 Parent(s): 86d2baf

Upload Transformers.ipynb

Browse files

the main model of the transformers

Files changed (1) hide show
  1. Transformers.ipynb +1634 -0
Transformers.ipynb ADDED
@@ -0,0 +1,1634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c3af7c60-ba26-4f75-bbe9-664347299dca",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Defaulting to user installation because normal site-packages is not writeable\n",
14
+ "Collecting transformers\n",
15
+ " Downloading transformers-4.39.1-py3-none-any.whl.metadata (134 kB)\n",
16
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
17
+ "\u001b[?25hCollecting datasets\n",
18
+ " Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)\n",
19
+ "Collecting accelerate\n",
20
+ " Downloading accelerate-0.28.0-py3-none-any.whl.metadata (18 kB)\n",
21
+ "Requirement already satisfied: filelock in /usr/lib/python3/dist-packages (from transformers) (3.6.0)\n",
22
+ "Collecting huggingface-hub<1.0,>=0.19.3 (from transformers)\n",
23
+ " Downloading huggingface_hub-0.22.1-py3-none-any.whl.metadata (12 kB)\n",
24
+ "Requirement already satisfied: numpy>=1.17 in ./.local/lib/python3.10/site-packages (from transformers) (1.25.2)\n",
25
+ "Requirement already satisfied: packaging>=20.0 in /usr/lib/python3/dist-packages (from transformers) (21.3)\n",
26
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
27
+ "Collecting regex!=2019.12.17 (from transformers)\n",
28
+ " Downloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n",
29
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
30
+ "\u001b[?25hRequirement already satisfied: requests in ./.local/lib/python3.10/site-packages (from transformers) (2.31.0)\n",
31
+ "Collecting tokenizers<0.19,>=0.14 (from transformers)\n",
32
+ " Downloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
33
+ "Collecting safetensors>=0.4.1 (from transformers)\n",
34
+ " Downloading safetensors-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n",
35
+ "Requirement already satisfied: tqdm>=4.27 in ./.local/lib/python3.10/site-packages (from transformers) (4.66.1)\n",
36
+ "Collecting pyarrow>=12.0.0 (from datasets)\n",
37
+ " Downloading pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.0 kB)\n",
38
+ "Collecting pyarrow-hotfix (from datasets)\n",
39
+ " Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)\n",
40
+ "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
41
+ " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
42
+ "Requirement already satisfied: pandas in /usr/lib/python3/dist-packages (from datasets) (1.3.5)\n",
43
+ "Collecting xxhash (from datasets)\n",
44
+ " Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
45
+ "Collecting multiprocess (from datasets)\n",
46
+ " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
47
+ "Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets)\n",
48
+ " Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)\n",
49
+ "Collecting aiohttp (from datasets)\n",
50
+ " Downloading aiohttp-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.4 kB)\n",
51
+ "Requirement already satisfied: psutil in /usr/lib/python3/dist-packages (from accelerate) (5.9.0)\n",
52
+ "Requirement already satisfied: torch>=1.10.0 in /usr/lib/python3/dist-packages (from accelerate) (2.0.1)\n",
53
+ "Collecting aiosignal>=1.1.2 (from aiohttp->datasets)\n",
54
+ " Downloading aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)\n",
55
+ "Requirement already satisfied: attrs>=17.3.0 in ./.local/lib/python3.10/site-packages (from aiohttp->datasets) (23.1.0)\n",
56
+ "Collecting frozenlist>=1.1.1 (from aiohttp->datasets)\n",
57
+ " Downloading frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
58
+ "Collecting multidict<7.0,>=4.5 (from aiohttp->datasets)\n",
59
+ " Downloading multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)\n",
60
+ "Collecting yarl<2.0,>=1.0 (from aiohttp->datasets)\n",
61
+ " Downloading yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (31 kB)\n",
62
+ "Collecting async-timeout<5.0,>=4.0 (from aiohttp->datasets)\n",
63
+ " Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)\n",
64
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.8.0)\n",
65
+ "Requirement already satisfied: charset-normalizer<4,>=2 in ./.local/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n",
66
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->transformers) (3.3)\n",
67
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/lib/python3/dist-packages (from requests->transformers) (1.26.5)\n",
68
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->transformers) (2020.6.20)\n",
69
+ "Downloading transformers-4.39.1-py3-none-any.whl (8.8 MB)\n",
70
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.8/8.8 MB\u001b[0m \u001b[31m208.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
71
+ "\u001b[?25hDownloading datasets-2.18.0-py3-none-any.whl (510 kB)\n",
72
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m80.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
73
+ "\u001b[?25hDownloading accelerate-0.28.0-py3-none-any.whl (290 kB)\n",
74
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m290.1/290.1 kB\u001b[0m \u001b[31m59.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
75
+ "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
76
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m24.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
77
+ "\u001b[?25hDownloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n",
78
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m170.9/170.9 kB\u001b[0m \u001b[31m33.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
79
+ "\u001b[?25hDownloading aiohttp-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n",
80
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m136.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
81
+ "\u001b[?25hDownloading huggingface_hub-0.22.1-py3-none-any.whl (388 kB)\n",
82
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m388.6/388.6 kB\u001b[0m \u001b[31m66.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
83
+ "\u001b[?25hDownloading pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n",
84
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m123.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
85
+ "\u001b[?25hDownloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (773 kB)\n",
86
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m774.0/774.0 kB\u001b[0m \u001b[31m97.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
87
+ "\u001b[?25hDownloading safetensors-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
88
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m125.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
89
+ "\u001b[?25hDownloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n",
90
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m194.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
91
+ "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
92
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
93
+ "\u001b[?25hDownloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n",
94
+ "Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
95
+ "\u001b[2K \u001b[90m��━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m50.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
96
+ "\u001b[?25hDownloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
97
+ "Downloading async_timeout-4.0.3-py3-none-any.whl (5.7 kB)\n",
98
+ "Downloading frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (239 kB)\n",
99
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m239.5/239.5 kB\u001b[0m \u001b[31m45.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
100
+ "\u001b[?25hDownloading multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (124 kB)\n",
101
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.3/124.3 kB\u001b[0m \u001b[31m29.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
102
+ "\u001b[?25hDownloading yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (301 kB)\n",
103
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.6/301.6 kB\u001b[0m \u001b[31m61.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
104
+ "\u001b[?25h\u001b[33mDEPRECATION: flatbuffers 1.12.1-git20200711.33e2d80-dfsg1-0.6 has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of flatbuffers or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
105
+ "\u001b[0mInstalling collected packages: xxhash, safetensors, regex, pyarrow-hotfix, pyarrow, multidict, fsspec, frozenlist, dill, async-timeout, yarl, multiprocess, huggingface-hub, aiosignal, tokenizers, aiohttp, accelerate, transformers, datasets\n",
106
+ "Successfully installed accelerate-0.28.0 aiohttp-3.9.3 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.18.0 dill-0.3.8 frozenlist-1.4.1 fsspec-2024.2.0 huggingface-hub-0.22.1 multidict-6.0.5 multiprocess-0.70.16 pyarrow-15.0.2 pyarrow-hotfix-0.6 regex-2023.12.25 safetensors-0.4.2 tokenizers-0.15.2 transformers-4.39.1 xxhash-3.4.1 yarl-1.9.4\n",
107
+ "\n",
108
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
109
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n"
110
+ ]
111
+ }
112
+ ],
113
+ "source": [
114
+ "! pip install transformers datasets accelerate"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 2,
120
+ "id": "0c24abf0-926e-4c37-9713-58dffe06ed03",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 3,
130
+ "id": "390d5322-3f72-49e5-b001-f66d943f0c2c",
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "task = \"cola\"\n",
135
+ "model_checkpoint = \"distilbert-base-uncased\"\n",
136
+ "batch_size = 16"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 4,
142
+ "id": "bece75f9-a5a2-45a6-aef0-33a2fafd6262",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "from datasets import load_dataset, load_metric"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 5,
152
+ "id": "a3bfef60-bd97-434e-9b83-560687ad4c08",
153
+ "metadata": {},
154
+ "outputs": [
155
+ {
156
+ "data": {
157
+ "application/vnd.jupyter.widget-view+json": {
158
+ "model_id": "1316f9ea215b4c99b67f5278ac5061fd",
159
+ "version_major": 2,
160
+ "version_minor": 0
161
+ },
162
+ "text/plain": [
163
+ "Downloading readme: 0%| | 0.00/35.3k [00:00<?, ?B/s]"
164
+ ]
165
+ },
166
+ "metadata": {},
167
+ "output_type": "display_data"
168
+ },
169
+ {
170
+ "name": "stderr",
171
+ "output_type": "stream",
172
+ "text": [
173
+ "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.25.2\n",
174
+ " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
175
+ "Downloading data: 100%|██████████| 251k/251k [00:00<00:00, 1.00MB/s]\n",
176
+ "Downloading data: 100%|██████████| 37.6k/37.6k [00:00<00:00, 251kB/s]\n",
177
+ "Downloading data: 100%|██████████| 37.7k/37.7k [00:00<00:00, 242kB/s]\n"
178
+ ]
179
+ },
180
+ {
181
+ "data": {
182
+ "application/vnd.jupyter.widget-view+json": {
183
+ "model_id": "a77c4b8db75c41bfbc994e0ecaf908cc",
184
+ "version_major": 2,
185
+ "version_minor": 0
186
+ },
187
+ "text/plain": [
188
+ "Generating train split: 0%| | 0/8551 [00:00<?, ? examples/s]"
189
+ ]
190
+ },
191
+ "metadata": {},
192
+ "output_type": "display_data"
193
+ },
194
+ {
195
+ "data": {
196
+ "application/vnd.jupyter.widget-view+json": {
197
+ "model_id": "3876016d10e841b19a5653055fb4962b",
198
+ "version_major": 2,
199
+ "version_minor": 0
200
+ },
201
+ "text/plain": [
202
+ "Generating validation split: 0%| | 0/1043 [00:00<?, ? examples/s]"
203
+ ]
204
+ },
205
+ "metadata": {},
206
+ "output_type": "display_data"
207
+ },
208
+ {
209
+ "data": {
210
+ "application/vnd.jupyter.widget-view+json": {
211
+ "model_id": "b67c3bc5ae4242f5af5d9fc548ed578b",
212
+ "version_major": 2,
213
+ "version_minor": 0
214
+ },
215
+ "text/plain": [
216
+ "Generating test split: 0%| | 0/1063 [00:00<?, ? examples/s]"
217
+ ]
218
+ },
219
+ "metadata": {},
220
+ "output_type": "display_data"
221
+ },
222
+ {
223
+ "name": "stderr",
224
+ "output_type": "stream",
225
+ "text": [
226
+ "/tmp/ipykernel_1505/1389288479.py:3: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
227
+ " metric = load_metric('glue', actual_task)\n",
228
+ "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/load.py:756: FutureWarning: The repository for glue contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.18.0/metrics/glue/glue.py\n",
229
+ "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
230
+ "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n",
231
+ " warnings.warn(\n"
232
+ ]
233
+ },
234
+ {
235
+ "data": {
236
+ "application/vnd.jupyter.widget-view+json": {
237
+ "model_id": "d01b1cf183a94d019c84f09d3f282235",
238
+ "version_major": 2,
239
+ "version_minor": 0
240
+ },
241
+ "text/plain": [
242
+ "Downloading builder script: 0%| | 0.00/1.84k [00:00<?, ?B/s]"
243
+ ]
244
+ },
245
+ "metadata": {},
246
+ "output_type": "display_data"
247
+ }
248
+ ],
249
+ "source": [
250
+ "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n",
251
+ "dataset = load_dataset(\"glue\", actual_task)\n",
252
+ "metric = load_metric('glue', actual_task)"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 6,
258
+ "id": "33cd1a8c-7ff3-475a-a434-1e90fb72af98",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "import datasets\n",
263
+ "import random\n",
264
+ "import pandas as pd\n",
265
+ "from IPython.display import display, HTML\n",
266
+ "\n",
267
+ "def show_random_elements(dataset, num_examples=10):\n",
268
+ " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
269
+ " picks = []\n",
270
+ " for _ in range(num_examples):\n",
271
+ " pick = random.randint(0, len(dataset)-1)\n",
272
+ " while pick in picks:\n",
273
+ " pick = random.randint(0, len(dataset)-1)\n",
274
+ " picks.append(pick)\n",
275
+ " \n",
276
+ " df = pd.DataFrame(dataset[picks])\n",
277
+ " for column, typ in dataset.features.items():\n",
278
+ " if isinstance(typ, datasets.ClassLabel):\n",
279
+ " df[column] = df[column].transform(lambda i: typ.names[i])\n",
280
+ " display(HTML(df.to_html()))"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": 7,
286
+ "id": "0800efbd-8b6a-43b9-8359-4c546e1a3e2d",
287
+ "metadata": {},
288
+ "outputs": [
289
+ {
290
+ "data": {
291
+ "text/html": [
292
+ "<table border=\"1\" class=\"dataframe\">\n",
293
+ " <thead>\n",
294
+ " <tr style=\"text-align: right;\">\n",
295
+ " <th></th>\n",
296
+ " <th>sentence</th>\n",
297
+ " <th>label</th>\n",
298
+ " <th>idx</th>\n",
299
+ " </tr>\n",
300
+ " </thead>\n",
301
+ " <tbody>\n",
302
+ " <tr>\n",
303
+ " <th>0</th>\n",
304
+ " <td>Mary jumped the horse perfectly over the last fence.</td>\n",
305
+ " <td>acceptable</td>\n",
306
+ " <td>705</td>\n",
307
+ " </tr>\n",
308
+ " <tr>\n",
309
+ " <th>1</th>\n",
310
+ " <td>John taught new students English Syntax.</td>\n",
311
+ " <td>acceptable</td>\n",
312
+ " <td>3951</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <th>2</th>\n",
316
+ " <td>This doll is hard to see it.</td>\n",
317
+ " <td>unacceptable</td>\n",
318
+ " <td>5018</td>\n",
319
+ " </tr>\n",
320
+ " <tr>\n",
321
+ " <th>3</th>\n",
322
+ " <td>I whipped the eggs from a puddle into a froth.</td>\n",
323
+ " <td>unacceptable</td>\n",
324
+ " <td>2298</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <th>4</th>\n",
328
+ " <td>Bill wants John to leave.</td>\n",
329
+ " <td>acceptable</td>\n",
330
+ " <td>6157</td>\n",
331
+ " </tr>\n",
332
+ " <tr>\n",
333
+ " <th>5</th>\n",
334
+ " <td>John expect to must leave.</td>\n",
335
+ " <td>unacceptable</td>\n",
336
+ " <td>4481</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <th>6</th>\n",
340
+ " <td>Bill's mother saw him.</td>\n",
341
+ " <td>acceptable</td>\n",
342
+ " <td>7569</td>\n",
343
+ " </tr>\n",
344
+ " <tr>\n",
345
+ " <th>7</th>\n",
346
+ " <td>Once Janet left, Fred became all the crazier.</td>\n",
347
+ " <td>acceptable</td>\n",
348
+ " <td>226</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <th>8</th>\n",
352
+ " <td>He's too reliable a man.</td>\n",
353
+ " <td>acceptable</td>\n",
354
+ " <td>5440</td>\n",
355
+ " </tr>\n",
356
+ " <tr>\n",
357
+ " <th>9</th>\n",
358
+ " <td>I wonder if she used paints.</td>\n",
359
+ " <td>acceptable</td>\n",
360
+ " <td>7425</td>\n",
361
+ " </tr>\n",
362
+ " </tbody>\n",
363
+ "</table>"
364
+ ],
365
+ "text/plain": [
366
+ "<IPython.core.display.HTML object>"
367
+ ]
368
+ },
369
+ "metadata": {},
370
+ "output_type": "display_data"
371
+ }
372
+ ],
373
+ "source": [
374
+ "show_random_elements(dataset[\"train\"])"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": 8,
380
+ "id": "ce74eb02-1bf1-4ce9-b9f9-34ed0d7d1f8f",
381
+ "metadata": {},
382
+ "outputs": [
383
+ {
384
+ "data": {
385
+ "text/plain": [
386
+ "{'matthews_correlation': 0.0416070055112537}"
387
+ ]
388
+ },
389
+ "execution_count": 8,
390
+ "metadata": {},
391
+ "output_type": "execute_result"
392
+ }
393
+ ],
394
+ "source": [
395
+ "import numpy as np\n",
396
+ "\n",
397
+ "fake_preds = np.random.randint(0, 2, size=(64,))\n",
398
+ "fake_labels = np.random.randint(0, 2, size=(64,))\n",
399
+ "metric.compute(predictions=fake_preds, references=fake_labels)"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "execution_count": 9,
405
+ "id": "f5bd6db5-8786-477b-89a6-7ca21414f4ec",
406
+ "metadata": {},
407
+ "outputs": [
408
+ {
409
+ "data": {
410
+ "application/vnd.jupyter.widget-view+json": {
411
+ "model_id": "9f5d7bb9f48b4c6b816427eeb8b5fe5d",
412
+ "version_major": 2,
413
+ "version_minor": 0
414
+ },
415
+ "text/plain": [
416
+ "tokenizer_config.json: 0%| | 0.00/28.0 [00:00<?, ?B/s]"
417
+ ]
418
+ },
419
+ "metadata": {},
420
+ "output_type": "display_data"
421
+ },
422
+ {
423
+ "data": {
424
+ "application/vnd.jupyter.widget-view+json": {
425
+ "model_id": "b6e68d7807c1445ab5554c7b6a838b73",
426
+ "version_major": 2,
427
+ "version_minor": 0
428
+ },
429
+ "text/plain": [
430
+ "config.json: 0%| | 0.00/483 [00:00<?, ?B/s]"
431
+ ]
432
+ },
433
+ "metadata": {},
434
+ "output_type": "display_data"
435
+ },
436
+ {
437
+ "data": {
438
+ "application/vnd.jupyter.widget-view+json": {
439
+ "model_id": "8bd709d5ebfe410ba0e4c8c0aa40f599",
440
+ "version_major": 2,
441
+ "version_minor": 0
442
+ },
443
+ "text/plain": [
444
+ "vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
445
+ ]
446
+ },
447
+ "metadata": {},
448
+ "output_type": "display_data"
449
+ },
450
+ {
451
+ "data": {
452
+ "application/vnd.jupyter.widget-view+json": {
453
+ "model_id": "c158c3cb051d4575b33c2ec22d9491b7",
454
+ "version_major": 2,
455
+ "version_minor": 0
456
+ },
457
+ "text/plain": [
458
+ "tokenizer.json: 0%| | 0.00/466k [00:00<?, ?B/s]"
459
+ ]
460
+ },
461
+ "metadata": {},
462
+ "output_type": "display_data"
463
+ }
464
+ ],
465
+ "source": [
466
+ "from transformers import AutoTokenizer\n",
467
+ " \n",
468
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": 10,
474
+ "id": "86da0827-3614-4f1c-969c-bc6c731225ab",
475
+ "metadata": {},
476
+ "outputs": [],
477
+ "source": [
478
+ "task_to_keys = {\n",
479
+ " \"cola\": (\"sentence\", None),\n",
480
+ " \"mnli\": (\"premise\", \"hypothesis\"),\n",
481
+ " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
482
+ " \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
483
+ " \"qnli\": (\"question\", \"sentence\"),\n",
484
+ " \"qqp\": (\"question1\", \"question2\"),\n",
485
+ " \"rte\": (\"sentence1\", \"sentence2\"),\n",
486
+ " \"sst2\": (\"sentence\", None),\n",
487
+ " \"stsb\": (\"sentence1\", \"sentence2\"),\n",
488
+ " \"wnli\": (\"sentence1\", \"sentence2\"),\n",
489
+ "}"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": 11,
495
+ "id": "ce31302c-0aae-40ca-b6f6-385303507eba",
496
+ "metadata": {},
497
+ "outputs": [
498
+ {
499
+ "name": "stdout",
500
+ "output_type": "stream",
501
+ "text": [
502
+ "Sentence: Our friends won't buy this analysis, let alone the next one we propose.\n"
503
+ ]
504
+ }
505
+ ],
506
+ "source": [
507
+ "sentence1_key, sentence2_key = task_to_keys[task]\n",
508
+ "if sentence2_key is None:\n",
509
+ " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
510
+ "else:\n",
511
+ " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
512
+ " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 12,
518
+ "id": "eefc459b-6833-4291-812a-65b6a6e29e71",
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "def preprocess_function(examples):\n",
523
+ " if sentence2_key is None:\n",
524
+ " return tokenizer(examples[sentence1_key], truncation=True)\n",
525
+ " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 13,
531
+ "id": "890f5781-8031-46cf-9ba3-c65f9ad29810",
532
+ "metadata": {},
533
+ "outputs": [
534
+ {
535
+ "data": {
536
+ "application/vnd.jupyter.widget-view+json": {
537
+ "model_id": "5f98cc317d1a45d0a8bd00c95e7ed505",
538
+ "version_major": 2,
539
+ "version_minor": 0
540
+ },
541
+ "text/plain": [
542
+ "Map: 0%| | 0/8551 [00:00<?, ? examples/s]"
543
+ ]
544
+ },
545
+ "metadata": {},
546
+ "output_type": "display_data"
547
+ },
548
+ {
549
+ "data": {
550
+ "application/vnd.jupyter.widget-view+json": {
551
+ "model_id": "4c73f2f8c73f473c836c96e31bbbbeae",
552
+ "version_major": 2,
553
+ "version_minor": 0
554
+ },
555
+ "text/plain": [
556
+ "Map: 0%| | 0/1043 [00:00<?, ? examples/s]"
557
+ ]
558
+ },
559
+ "metadata": {},
560
+ "output_type": "display_data"
561
+ },
562
+ {
563
+ "data": {
564
+ "application/vnd.jupyter.widget-view+json": {
565
+ "model_id": "45cb3ffddeee41da8ec9b5f83a93d076",
566
+ "version_major": 2,
567
+ "version_minor": 0
568
+ },
569
+ "text/plain": [
570
+ "Map: 0%| | 0/1063 [00:00<?, ? examples/s]"
571
+ ]
572
+ },
573
+ "metadata": {},
574
+ "output_type": "display_data"
575
+ }
576
+ ],
577
+ "source": [
578
+ "encoded_dataset = dataset.map(preprocess_function, batched=True)"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": 14,
584
+ "id": "656bfda6-45c8-4843-b7c3-f70e31578abe",
585
+ "metadata": {},
586
+ "outputs": [
587
+ {
588
+ "name": "stderr",
589
+ "output_type": "stream",
590
+ "text": [
591
+ "2024-03-27 11:00:29.468986: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
592
+ "2024-03-27 11:00:29.672421: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
593
+ "To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
594
+ ]
595
+ },
596
+ {
597
+ "data": {
598
+ "application/vnd.jupyter.widget-view+json": {
599
+ "model_id": "ff7f8f4314b14a43be4b599015552608",
600
+ "version_major": 2,
601
+ "version_minor": 0
602
+ },
603
+ "text/plain": [
604
+ "model.safetensors: 0%| | 0.00/268M [00:00<?, ?B/s]"
605
+ ]
606
+ },
607
+ "metadata": {},
608
+ "output_type": "display_data"
609
+ },
610
+ {
611
+ "name": "stderr",
612
+ "output_type": "stream",
613
+ "text": [
614
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
615
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
616
+ ]
617
+ }
618
+ ],
619
+ "source": [
620
+ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
621
+ "\n",
622
+ "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
623
+ "model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": 15,
629
+ "id": "6b50c21a-abaa-41b6-85b9-952540de64d1",
630
+ "metadata": {},
631
+ "outputs": [],
632
+ "source": [
633
+ "metric_name = \"pearson\" if task == \"stsb\" else \"matthews_correlation\" if task == \"cola\" else \"accuracy\"\n",
634
+ "\n",
635
+ "args = TrainingArguments(\n",
636
+ " \"test-glue\",\n",
637
+ " evaluation_strategy = \"epoch\",\n",
638
+ " save_strategy = \"epoch\",\n",
639
+ " learning_rate=2e-5,\n",
640
+ " per_device_train_batch_size=batch_size,\n",
641
+ " per_device_eval_batch_size=batch_size,\n",
642
+ " num_train_epochs=5,\n",
643
+ " weight_decay=0.01,\n",
644
+ " load_best_model_at_end=True,\n",
645
+ " metric_for_best_model=metric_name,\n",
646
+ ")"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": 19,
652
+ "id": "65c8eb57-9536-42cd-91d4-33536ce383f3",
653
+ "metadata": {},
654
+ "outputs": [],
655
+ "source": [
656
+ "def compute_metrics(eval_pred):\n",
657
+ " predictions, labels = eval_pred\n",
658
+ " if task != \"stsb\":\n",
659
+ " predictions = np.argmax(predictions, axis=1)\n",
660
+ " else:\n",
661
+ " predictions = predictions[:, 0]\n",
662
+ " return metric.compute(predictions=predictions, references=labels)"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": 20,
668
+ "id": "cb789ab8-0887-487b-9bfe-7c5e84aa66ec",
669
+ "metadata": {},
670
+ "outputs": [],
671
+ "source": [
672
+ "validation_key = \"validation_mismatched\" if task == \"mnli-mm\" else \"validation_matched\" if task == \"mnli\" else \"validation\"\n",
673
+ "trainer = Trainer(\n",
674
+ " model,\n",
675
+ " args,\n",
676
+ " train_dataset=encoded_dataset[\"train\"],\n",
677
+ " eval_dataset=encoded_dataset[validation_key],\n",
678
+ " tokenizer=tokenizer,\n",
679
+ " compute_metrics=compute_metrics\n",
680
+ ")"
681
+ ]
682
+ },
683
+ {
684
+ "cell_type": "code",
685
+ "execution_count": 21,
686
+ "id": "8985fb22-4809-46e0-a6ab-c7df3e2a1e89",
687
+ "metadata": {},
688
+ "outputs": [
689
+ {
690
+ "data": {
691
+ "text/html": [
692
+ "\n",
693
+ " <div>\n",
694
+ " \n",
695
+ " <progress value='2675' max='2675' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
696
+ " [2675/2675 01:14, Epoch 5/5]\n",
697
+ " </div>\n",
698
+ " <table border=\"1\" class=\"dataframe\">\n",
699
+ " <thead>\n",
700
+ " <tr style=\"text-align: left;\">\n",
701
+ " <th>Epoch</th>\n",
702
+ " <th>Training Loss</th>\n",
703
+ " <th>Validation Loss</th>\n",
704
+ " <th>Matthews Correlation</th>\n",
705
+ " </tr>\n",
706
+ " </thead>\n",
707
+ " <tbody>\n",
708
+ " <tr>\n",
709
+ " <td>1</td>\n",
710
+ " <td>0.519000</td>\n",
711
+ " <td>0.472218</td>\n",
712
+ " <td>0.430751</td>\n",
713
+ " </tr>\n",
714
+ " <tr>\n",
715
+ " <td>2</td>\n",
716
+ " <td>0.349800</td>\n",
717
+ " <td>0.502173</td>\n",
718
+ " <td>0.535758</td>\n",
719
+ " </tr>\n",
720
+ " <tr>\n",
721
+ " <td>3</td>\n",
722
+ " <td>0.238200</td>\n",
723
+ " <td>0.617800</td>\n",
724
+ " <td>0.541004</td>\n",
725
+ " </tr>\n",
726
+ " <tr>\n",
727
+ " <td>4</td>\n",
728
+ " <td>0.173400</td>\n",
729
+ " <td>0.744248</td>\n",
730
+ " <td>0.549477</td>\n",
731
+ " </tr>\n",
732
+ " <tr>\n",
733
+ " <td>5</td>\n",
734
+ " <td>0.127800</td>\n",
735
+ " <td>0.803236</td>\n",
736
+ " <td>0.550403</td>\n",
737
+ " </tr>\n",
738
+ " </tbody>\n",
739
+ "</table><p>"
740
+ ],
741
+ "text/plain": [
742
+ "<IPython.core.display.HTML object>"
743
+ ]
744
+ },
745
+ "metadata": {},
746
+ "output_type": "display_data"
747
+ },
748
+ {
749
+ "data": {
750
+ "text/plain": [
751
+ "TrainOutput(global_step=2675, training_loss=0.27159803158768986, metrics={'train_runtime': 75.2661, 'train_samples_per_second': 568.051, 'train_steps_per_second': 35.541, 'total_flos': 229000686898068.0, 'train_loss': 0.27159803158768986, 'epoch': 5.0})"
752
+ ]
753
+ },
754
+ "execution_count": 21,
755
+ "metadata": {},
756
+ "output_type": "execute_result"
757
+ }
758
+ ],
759
+ "source": [
760
+ "trainer.train()"
761
+ ]
762
+ },
763
+ {
764
+ "cell_type": "code",
765
+ "execution_count": 22,
766
+ "id": "e4106e5c-a37d-4e8f-b880-339e42daf57f",
767
+ "metadata": {},
768
+ "outputs": [
769
+ {
770
+ "data": {
771
+ "text/html": [
772
+ "\n",
773
+ " <div>\n",
774
+ " \n",
775
+ " <progress value='66' max='66' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
776
+ " [66/66 00:00]\n",
777
+ " </div>\n",
778
+ " "
779
+ ],
780
+ "text/plain": [
781
+ "<IPython.core.display.HTML object>"
782
+ ]
783
+ },
784
+ "metadata": {},
785
+ "output_type": "display_data"
786
+ },
787
+ {
788
+ "data": {
789
+ "text/plain": [
790
+ "{'eval_loss': 0.8032358288764954,\n",
791
+ " 'eval_matthews_correlation': 0.5504031254980248,\n",
792
+ " 'eval_runtime': 0.3257,\n",
793
+ " 'eval_samples_per_second': 3201.883,\n",
794
+ " 'eval_steps_per_second': 202.612,\n",
795
+ " 'epoch': 5.0}"
796
+ ]
797
+ },
798
+ "execution_count": 22,
799
+ "metadata": {},
800
+ "output_type": "execute_result"
801
+ }
802
+ ],
803
+ "source": [
804
+ "trainer.evaluate()"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": 23,
810
+ "id": "703d1296-ce54-4281-b7d3-d487e545343a",
811
+ "metadata": {},
812
+ "outputs": [
813
+ {
814
+ "name": "stderr",
815
+ "output_type": "stream",
816
+ "text": [
817
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
818
+ "To disable this warning, you can either:\n",
819
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
820
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
821
+ ]
822
+ },
823
+ {
824
+ "name": "stdout",
825
+ "output_type": "stream",
826
+ "text": [
827
+ "Defaulting to user installation because normal site-packages is not writeable\n",
828
+ "Collecting optuna\n",
829
+ " Downloading optuna-3.6.0-py3-none-any.whl.metadata (17 kB)\n",
830
+ "Collecting alembic>=1.5.0 (from optuna)\n",
831
+ " Downloading alembic-1.13.1-py3-none-any.whl.metadata (7.4 kB)\n",
832
+ "Collecting colorlog (from optuna)\n",
833
+ " Downloading colorlog-6.8.2-py3-none-any.whl.metadata (10 kB)\n",
834
+ "Requirement already satisfied: numpy in ./.local/lib/python3.10/site-packages (from optuna) (1.25.2)\n",
835
+ "Requirement already satisfied: packaging>=20.0 in /usr/lib/python3/dist-packages (from optuna) (21.3)\n",
836
+ "Collecting sqlalchemy>=1.3.0 (from optuna)\n",
837
+ " Downloading SQLAlchemy-2.0.29-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)\n",
838
+ "Requirement already satisfied: tqdm in ./.local/lib/python3.10/site-packages (from optuna) (4.66.1)\n",
839
+ "Requirement already satisfied: PyYAML in /usr/lib/python3/dist-packages (from optuna) (5.4.1)\n",
840
+ "Collecting Mako (from alembic>=1.5.0->optuna)\n",
841
+ " Downloading Mako-1.3.2-py3-none-any.whl.metadata (2.9 kB)\n",
842
+ "Requirement already satisfied: typing-extensions>=4 in ./.local/lib/python3.10/site-packages (from alembic>=1.5.0->optuna) (4.8.0)\n",
843
+ "Collecting greenlet!=0.4.17 (from sqlalchemy>=1.3.0->optuna)\n",
844
+ " Downloading greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (3.8 kB)\n",
845
+ "Requirement already satisfied: MarkupSafe>=0.9.2 in /usr/lib/python3/dist-packages (from Mako->alembic>=1.5.0->optuna) (2.0.1)\n",
846
+ "Downloading optuna-3.6.0-py3-none-any.whl (379 kB)\n",
847
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m379.9/379.9 kB\u001b[0m \u001b[31m27.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
848
+ "\u001b[?25hDownloading alembic-1.13.1-py3-none-any.whl (233 kB)\n",
849
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m233.4/233.4 kB\u001b[0m \u001b[31m68.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
850
+ "\u001b[?25hDownloading SQLAlchemy-2.0.29-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
851
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m209.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
852
+ "\u001b[?25hDownloading colorlog-6.8.2-py3-none-any.whl (11 kB)\n",
853
+ "Downloading greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (616 kB)\n",
854
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m616.0/616.0 kB\u001b[0m \u001b[31m127.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
855
+ "\u001b[?25hDownloading Mako-1.3.2-py3-none-any.whl (78 kB)\n",
856
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.7/78.7 kB\u001b[0m \u001b[31m25.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
857
+ "\u001b[?25h\u001b[33mDEPRECATION: flatbuffers 1.12.1-git20200711.33e2d80-dfsg1-0.6 has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of flatbuffers or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
858
+ "\u001b[0mInstalling collected packages: Mako, greenlet, colorlog, sqlalchemy, alembic, optuna\n",
859
+ "Successfully installed Mako-1.3.2 alembic-1.13.1 colorlog-6.8.2 greenlet-3.0.3 optuna-3.6.0 sqlalchemy-2.0.29\n",
860
+ "\n",
861
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
862
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n"
863
+ ]
864
+ },
865
+ {
866
+ "name": "stderr",
867
+ "output_type": "stream",
868
+ "text": [
869
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
870
+ "To disable this warning, you can either:\n",
871
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
872
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
873
+ ]
874
+ },
875
+ {
876
+ "name": "stdout",
877
+ "output_type": "stream",
878
+ "text": [
879
+ "Defaulting to user installation because normal site-packages is not writeable\n",
880
+ "Collecting ray[tune]\n",
881
+ " Downloading ray-2.10.0-cp310-cp310-manylinux2014_x86_64.whl.metadata (13 kB)\n",
882
+ "Requirement already satisfied: click>=7.0 in /usr/lib/python3/dist-packages (from ray[tune]) (8.0.3)\n",
883
+ "Requirement already satisfied: filelock in /usr/lib/python3/dist-packages (from ray[tune]) (3.6.0)\n",
884
+ "Requirement already satisfied: jsonschema in ./.local/lib/python3.10/site-packages (from ray[tune]) (4.20.0)\n",
885
+ "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/lib/python3/dist-packages (from ray[tune]) (1.0.3)\n",
886
+ "Requirement already satisfied: packaging in /usr/lib/python3/dist-packages (from ray[tune]) (21.3)\n",
887
+ "Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/lib/python3/dist-packages (from ray[tune]) (4.21.12)\n",
888
+ "Requirement already satisfied: pyyaml in /usr/lib/python3/dist-packages (from ray[tune]) (5.4.1)\n",
889
+ "Requirement already satisfied: aiosignal in ./.local/lib/python3.10/site-packages (from ray[tune]) (1.3.1)\n",
890
+ "Requirement already satisfied: frozenlist in ./.local/lib/python3.10/site-packages (from ray[tune]) (1.4.1)\n",
891
+ "Requirement already satisfied: requests in ./.local/lib/python3.10/site-packages (from ray[tune]) (2.31.0)\n",
892
+ "Requirement already satisfied: pandas in /usr/lib/python3/dist-packages (from ray[tune]) (1.3.5)\n",
893
+ "Collecting tensorboardX>=1.9 (from ray[tune])\n",
894
+ " Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)\n",
895
+ "Requirement already satisfied: pyarrow>=6.0.1 in ./.local/lib/python3.10/site-packages (from ray[tune]) (15.0.2)\n",
896
+ "Requirement already satisfied: fsspec in ./.local/lib/python3.10/site-packages (from ray[tune]) (2024.2.0)\n",
897
+ "Requirement already satisfied: numpy<2,>=1.16.6 in ./.local/lib/python3.10/site-packages (from pyarrow>=6.0.1->ray[tune]) (1.25.2)\n",
898
+ "Requirement already satisfied: attrs>=22.2.0 in ./.local/lib/python3.10/site-packages (from jsonschema->ray[tune]) (23.1.0)\n",
899
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in ./.local/lib/python3.10/site-packages (from jsonschema->ray[tune]) (2023.11.2)\n",
900
+ "Requirement already satisfied: referencing>=0.28.4 in ./.local/lib/python3.10/site-packages (from jsonschema->ray[tune]) (0.31.1)\n",
901
+ "Requirement already satisfied: rpds-py>=0.7.1 in ./.local/lib/python3.10/site-packages (from jsonschema->ray[tune]) (0.13.2)\n",
902
+ "Requirement already satisfied: charset-normalizer<4,>=2 in ./.local/lib/python3.10/site-packages (from requests->ray[tune]) (3.3.2)\n",
903
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->ray[tune]) (3.3)\n",
904
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/lib/python3/dist-packages (from requests->ray[tune]) (1.26.5)\n",
905
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->ray[tune]) (2020.6.20)\n",
906
+ "Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)\n",
907
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.7/101.7 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
908
+ "\u001b[?25hDownloading ray-2.10.0-cp310-cp310-manylinux2014_x86_64.whl (65.1 MB)\n",
909
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.1/65.1 MB\u001b[0m \u001b[31m97.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
910
+ "\u001b[?25h\u001b[33mDEPRECATION: flatbuffers 1.12.1-git20200711.33e2d80-dfsg1-0.6 has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of flatbuffers or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
911
+ "\u001b[0mInstalling collected packages: tensorboardX, ray\n",
912
+ "Successfully installed ray-2.10.0 tensorboardX-2.6.2.2\n",
913
+ "\n",
914
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
915
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3 -m pip install --upgrade pip\u001b[0m\n"
916
+ ]
917
+ }
918
+ ],
919
+ "source": [
920
+ "! pip install optuna\n",
921
+ "! pip install ray[tune]"
922
+ ]
923
+ },
924
+ {
925
+ "cell_type": "code",
926
+ "execution_count": 24,
927
+ "id": "fae555d4-8640-4a81-9b49-4a9d9a5ab9b5",
928
+ "metadata": {},
929
+ "outputs": [],
930
+ "source": [
931
+ "def model_init():\n",
932
+ " return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)"
933
+ ]
934
+ },
935
+ {
936
+ "cell_type": "code",
937
+ "execution_count": 25,
938
+ "id": "ac0f793c-8418-48d1-9b37-41005f0095c3",
939
+ "metadata": {},
940
+ "outputs": [
941
+ {
942
+ "name": "stderr",
943
+ "output_type": "stream",
944
+ "text": [
945
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
946
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
947
+ " warnings.warn(\n",
948
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
949
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
950
+ ]
951
+ }
952
+ ],
953
+ "source": [
954
+ "trainer = Trainer(\n",
955
+ " model_init=model_init,\n",
956
+ " args=args,\n",
957
+ " train_dataset=encoded_dataset[\"train\"],\n",
958
+ " eval_dataset=encoded_dataset[validation_key],\n",
959
+ " tokenizer=tokenizer,\n",
960
+ " compute_metrics=compute_metrics\n",
961
+ ")"
962
+ ]
963
+ },
964
+ {
965
+ "cell_type": "code",
966
+ "execution_count": 26,
967
+ "id": "7d74518a-ebc0-43ac-accb-65c32d5ec118",
968
+ "metadata": {},
969
+ "outputs": [
970
+ {
971
+ "name": "stderr",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "[I 2024-03-27 11:07:46,609] A new study created in memory with name: no-name-f7c7ff48-4767-4715-9c09-9c4565193c42\n",
975
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
976
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
977
+ " warnings.warn(\n",
978
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
979
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
980
+ ]
981
+ },
982
+ {
983
+ "data": {
984
+ "text/html": [
985
+ "\n",
986
+ " <div>\n",
987
+ " \n",
988
+ " <progress value='2140' max='2140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
989
+ " [2140/2140 00:59, Epoch 4/4]\n",
990
+ " </div>\n",
991
+ " <table border=\"1\" class=\"dataframe\">\n",
992
+ " <thead>\n",
993
+ " <tr style=\"text-align: left;\">\n",
994
+ " <th>Epoch</th>\n",
995
+ " <th>Training Loss</th>\n",
996
+ " <th>Validation Loss</th>\n",
997
+ " <th>Matthews Correlation</th>\n",
998
+ " </tr>\n",
999
+ " </thead>\n",
1000
+ " <tbody>\n",
1001
+ " <tr>\n",
1002
+ " <td>1</td>\n",
1003
+ " <td>0.568600</td>\n",
1004
+ " <td>0.528286</td>\n",
1005
+ " <td>0.318150</td>\n",
1006
+ " </tr>\n",
1007
+ " <tr>\n",
1008
+ " <td>2</td>\n",
1009
+ " <td>0.390500</td>\n",
1010
+ " <td>0.564842</td>\n",
1011
+ " <td>0.387962</td>\n",
1012
+ " </tr>\n",
1013
+ " <tr>\n",
1014
+ " <td>3</td>\n",
1015
+ " <td>0.237300</td>\n",
1016
+ " <td>0.725552</td>\n",
1017
+ " <td>0.436872</td>\n",
1018
+ " </tr>\n",
1019
+ " <tr>\n",
1020
+ " <td>4</td>\n",
1021
+ " <td>0.139100</td>\n",
1022
+ " <td>0.973828</td>\n",
1023
+ " <td>0.429154</td>\n",
1024
+ " </tr>\n",
1025
+ " </tbody>\n",
1026
+ "</table><p>"
1027
+ ],
1028
+ "text/plain": [
1029
+ "<IPython.core.display.HTML object>"
1030
+ ]
1031
+ },
1032
+ "metadata": {},
1033
+ "output_type": "display_data"
1034
+ },
1035
+ {
1036
+ "name": "stderr",
1037
+ "output_type": "stream",
1038
+ "text": [
1039
+ "[I 2024-03-27 11:08:46,135] Trial 0 finished with value: 0.42915398713994973 and parameters: {'learning_rate': 6.658969020177832e-05, 'num_train_epochs': 4, 'seed': 11, 'per_device_train_batch_size': 16}. Best is trial 0 with value: 0.42915398713994973.\n",
1040
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1041
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1042
+ " warnings.warn(\n",
1043
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1044
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1045
+ ]
1046
+ },
1047
+ {
1048
+ "data": {
1049
+ "text/html": [
1050
+ "\n",
1051
+ " <div>\n",
1052
+ " \n",
1053
+ " <progress value='402' max='402' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1054
+ " [402/402 00:26, Epoch 3/3]\n",
1055
+ " </div>\n",
1056
+ " <table border=\"1\" class=\"dataframe\">\n",
1057
+ " <thead>\n",
1058
+ " <tr style=\"text-align: left;\">\n",
1059
+ " <th>Epoch</th>\n",
1060
+ " <th>Training Loss</th>\n",
1061
+ " <th>Validation Loss</th>\n",
1062
+ " <th>Matthews Correlation</th>\n",
1063
+ " </tr>\n",
1064
+ " </thead>\n",
1065
+ " <tbody>\n",
1066
+ " <tr>\n",
1067
+ " <td>1</td>\n",
1068
+ " <td>No log</td>\n",
1069
+ " <td>0.531186</td>\n",
1070
+ " <td>0.332502</td>\n",
1071
+ " </tr>\n",
1072
+ " <tr>\n",
1073
+ " <td>2</td>\n",
1074
+ " <td>No log</td>\n",
1075
+ " <td>0.503717</td>\n",
1076
+ " <td>0.443275</td>\n",
1077
+ " </tr>\n",
1078
+ " <tr>\n",
1079
+ " <td>3</td>\n",
1080
+ " <td>No log</td>\n",
1081
+ " <td>0.507968</td>\n",
1082
+ " <td>0.439255</td>\n",
1083
+ " </tr>\n",
1084
+ " </tbody>\n",
1085
+ "</table><p>"
1086
+ ],
1087
+ "text/plain": [
1088
+ "<IPython.core.display.HTML object>"
1089
+ ]
1090
+ },
1091
+ "metadata": {},
1092
+ "output_type": "display_data"
1093
+ },
1094
+ {
1095
+ "name": "stderr",
1096
+ "output_type": "stream",
1097
+ "text": [
1098
+ "[I 2024-03-27 11:09:13,247] Trial 1 finished with value: 0.4392548203439382 and parameters: {'learning_rate': 1.1290628476063563e-05, 'num_train_epochs': 3, 'seed': 28, 'per_device_train_batch_size': 64}. Best is trial 1 with value: 0.4392548203439382.\n",
1099
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1100
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1101
+ " warnings.warn(\n",
1102
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1103
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1104
+ ]
1105
+ },
1106
+ {
1107
+ "data": {
1108
+ "text/html": [
1109
+ "\n",
1110
+ " <div>\n",
1111
+ " \n",
1112
+ " <progress value='8552' max='8552' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1113
+ " [8552/8552 03:23, Epoch 4/4]\n",
1114
+ " </div>\n",
1115
+ " <table border=\"1\" class=\"dataframe\">\n",
1116
+ " <thead>\n",
1117
+ " <tr style=\"text-align: left;\">\n",
1118
+ " <th>Epoch</th>\n",
1119
+ " <th>Training Loss</th>\n",
1120
+ " <th>Validation Loss</th>\n",
1121
+ " <th>Matthews Correlation</th>\n",
1122
+ " </tr>\n",
1123
+ " </thead>\n",
1124
+ " <tbody>\n",
1125
+ " <tr>\n",
1126
+ " <td>1</td>\n",
1127
+ " <td>0.531300</td>\n",
1128
+ " <td>0.566970</td>\n",
1129
+ " <td>0.414967</td>\n",
1130
+ " </tr>\n",
1131
+ " <tr>\n",
1132
+ " <td>2</td>\n",
1133
+ " <td>0.512400</td>\n",
1134
+ " <td>0.786295</td>\n",
1135
+ " <td>0.472533</td>\n",
1136
+ " </tr>\n",
1137
+ " <tr>\n",
1138
+ " <td>3</td>\n",
1139
+ " <td>0.381700</td>\n",
1140
+ " <td>0.904949</td>\n",
1141
+ " <td>0.502075</td>\n",
1142
+ " </tr>\n",
1143
+ " <tr>\n",
1144
+ " <td>4</td>\n",
1145
+ " <td>0.272600</td>\n",
1146
+ " <td>1.014711</td>\n",
1147
+ " <td>0.494873</td>\n",
1148
+ " </tr>\n",
1149
+ " </tbody>\n",
1150
+ "</table><p>"
1151
+ ],
1152
+ "text/plain": [
1153
+ "<IPython.core.display.HTML object>"
1154
+ ]
1155
+ },
1156
+ "metadata": {},
1157
+ "output_type": "display_data"
1158
+ },
1159
+ {
1160
+ "name": "stderr",
1161
+ "output_type": "stream",
1162
+ "text": [
1163
+ "[I 2024-03-27 11:12:37,216] Trial 2 finished with value: 0.4948726793760845 and parameters: {'learning_rate': 8.36801127282771e-06, 'num_train_epochs': 4, 'seed': 12, 'per_device_train_batch_size': 4}. Best is trial 2 with value: 0.4948726793760845.\n",
1164
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1165
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1166
+ " warnings.warn(\n",
1167
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1168
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1169
+ ]
1170
+ },
1171
+ {
1172
+ "data": {
1173
+ "text/html": [
1174
+ "\n",
1175
+ " <div>\n",
1176
+ " \n",
1177
+ " <progress value='536' max='536' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1178
+ " [536/536 00:21, Epoch 2/2]\n",
1179
+ " </div>\n",
1180
+ " <table border=\"1\" class=\"dataframe\">\n",
1181
+ " <thead>\n",
1182
+ " <tr style=\"text-align: left;\">\n",
1183
+ " <th>Epoch</th>\n",
1184
+ " <th>Training Loss</th>\n",
1185
+ " <th>Validation Loss</th>\n",
1186
+ " <th>Matthews Correlation</th>\n",
1187
+ " </tr>\n",
1188
+ " </thead>\n",
1189
+ " <tbody>\n",
1190
+ " <tr>\n",
1191
+ " <td>1</td>\n",
1192
+ " <td>No log</td>\n",
1193
+ " <td>0.479286</td>\n",
1194
+ " <td>0.436850</td>\n",
1195
+ " </tr>\n",
1196
+ " <tr>\n",
1197
+ " <td>2</td>\n",
1198
+ " <td>0.414800</td>\n",
1199
+ " <td>0.520329</td>\n",
1200
+ " <td>0.502552</td>\n",
1201
+ " </tr>\n",
1202
+ " </tbody>\n",
1203
+ "</table><p>"
1204
+ ],
1205
+ "text/plain": [
1206
+ "<IPython.core.display.HTML object>"
1207
+ ]
1208
+ },
1209
+ "metadata": {},
1210
+ "output_type": "display_data"
1211
+ },
1212
+ {
1213
+ "name": "stderr",
1214
+ "output_type": "stream",
1215
+ "text": [
1216
+ "[I 2024-03-27 11:12:59,219] Trial 3 finished with value: 0.5025517897100551 and parameters: {'learning_rate': 9.440074279431108e-05, 'num_train_epochs': 2, 'seed': 17, 'per_device_train_batch_size': 32}. Best is trial 3 with value: 0.5025517897100551.\n",
1217
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1218
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1219
+ " warnings.warn(\n",
1220
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1221
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1222
+ ]
1223
+ },
1224
+ {
1225
+ "data": {
1226
+ "text/html": [
1227
+ "\n",
1228
+ " <div>\n",
1229
+ " \n",
1230
+ " <progress value='535' max='535' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1231
+ " [535/535 00:14, Epoch 1/1]\n",
1232
+ " </div>\n",
1233
+ " <table border=\"1\" class=\"dataframe\">\n",
1234
+ " <thead>\n",
1235
+ " <tr style=\"text-align: left;\">\n",
1236
+ " <th>Epoch</th>\n",
1237
+ " <th>Training Loss</th>\n",
1238
+ " <th>Validation Loss</th>\n",
1239
+ " <th>Matthews Correlation</th>\n",
1240
+ " </tr>\n",
1241
+ " </thead>\n",
1242
+ " <tbody>\n",
1243
+ " <tr>\n",
1244
+ " <td>1</td>\n",
1245
+ " <td>0.615000</td>\n",
1246
+ " <td>0.603050</td>\n",
1247
+ " <td>0.000000</td>\n",
1248
+ " </tr>\n",
1249
+ " </tbody>\n",
1250
+ "</table><p>"
1251
+ ],
1252
+ "text/plain": [
1253
+ "<IPython.core.display.HTML object>"
1254
+ ]
1255
+ },
1256
+ "metadata": {},
1257
+ "output_type": "display_data"
1258
+ },
1259
+ {
1260
+ "name": "stderr",
1261
+ "output_type": "stream",
1262
+ "text": [
1263
+ "/usr/lib/python3/dist-packages/sklearn/metrics/_classification.py:846: RuntimeWarning: invalid value encountered in scalar divide\n",
1264
+ " mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)\n",
1265
+ "[I 2024-03-27 11:13:14,620] Trial 4 finished with value: 0.0 and parameters: {'learning_rate': 1.8300985987395685e-06, 'num_train_epochs': 1, 'seed': 13, 'per_device_train_batch_size': 16}. Best is trial 3 with value: 0.5025517897100551.\n",
1266
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1267
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1268
+ " warnings.warn(\n",
1269
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1270
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1271
+ ]
1272
+ },
1273
+ {
1274
+ "data": {
1275
+ "text/html": [
1276
+ "\n",
1277
+ " <div>\n",
1278
+ " \n",
1279
+ " <progress value='2138' max='10690' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1280
+ " [ 2138/10690 00:50 < 03:20, 42.59 it/s, Epoch 1/5]\n",
1281
+ " </div>\n",
1282
+ " <table border=\"1\" class=\"dataframe\">\n",
1283
+ " <thead>\n",
1284
+ " <tr style=\"text-align: left;\">\n",
1285
+ " <th>Epoch</th>\n",
1286
+ " <th>Training Loss</th>\n",
1287
+ " <th>Validation Loss</th>\n",
1288
+ " <th>Matthews Correlation</th>\n",
1289
+ " </tr>\n",
1290
+ " </thead>\n",
1291
+ " <tbody>\n",
1292
+ " <tr>\n",
1293
+ " <td>1</td>\n",
1294
+ " <td>0.535100</td>\n",
1295
+ " <td>0.573925</td>\n",
1296
+ " <td>0.380639</td>\n",
1297
+ " </tr>\n",
1298
+ " </tbody>\n",
1299
+ "</table><p>"
1300
+ ],
1301
+ "text/plain": [
1302
+ "<IPython.core.display.HTML object>"
1303
+ ]
1304
+ },
1305
+ "metadata": {},
1306
+ "output_type": "display_data"
1307
+ },
1308
+ {
1309
+ "name": "stderr",
1310
+ "output_type": "stream",
1311
+ "text": [
1312
+ "[I 2024-03-27 11:14:05,400] Trial 5 pruned. \n",
1313
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1314
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1315
+ ]
1316
+ },
1317
+ {
1318
+ "data": {
1319
+ "text/html": [
1320
+ "\n",
1321
+ " <div>\n",
1322
+ " \n",
1323
+ " <progress value='134' max='402' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1324
+ " [134/402 00:08 < 00:16, 16.04 it/s, Epoch 1/3]\n",
1325
+ " </div>\n",
1326
+ " <table border=\"1\" class=\"dataframe\">\n",
1327
+ " <thead>\n",
1328
+ " <tr style=\"text-align: left;\">\n",
1329
+ " <th>Epoch</th>\n",
1330
+ " <th>Training Loss</th>\n",
1331
+ " <th>Validation Loss</th>\n",
1332
+ " <th>Matthews Correlation</th>\n",
1333
+ " </tr>\n",
1334
+ " </thead>\n",
1335
+ " <tbody>\n",
1336
+ " <tr>\n",
1337
+ " <td>1</td>\n",
1338
+ " <td>No log</td>\n",
1339
+ " <td>0.598633</td>\n",
1340
+ " <td>0.000000</td>\n",
1341
+ " </tr>\n",
1342
+ " </tbody>\n",
1343
+ "</table><p>"
1344
+ ],
1345
+ "text/plain": [
1346
+ "<IPython.core.display.HTML object>"
1347
+ ]
1348
+ },
1349
+ "metadata": {},
1350
+ "output_type": "display_data"
1351
+ },
1352
+ {
1353
+ "name": "stderr",
1354
+ "output_type": "stream",
1355
+ "text": [
1356
+ "/usr/lib/python3/dist-packages/sklearn/metrics/_classification.py:846: RuntimeWarning: invalid value encountered in scalar divide\n",
1357
+ " mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)\n",
1358
+ "[I 2024-03-27 11:14:14,176] Trial 6 pruned. \n",
1359
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1360
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1361
+ ]
1362
+ },
1363
+ {
1364
+ "data": {
1365
+ "text/html": [
1366
+ "\n",
1367
+ " <div>\n",
1368
+ " \n",
1369
+ " <progress value='1069' max='1069' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1370
+ " [1069/1069 00:26, Epoch 1/1]\n",
1371
+ " </div>\n",
1372
+ " <table border=\"1\" class=\"dataframe\">\n",
1373
+ " <thead>\n",
1374
+ " <tr style=\"text-align: left;\">\n",
1375
+ " <th>Epoch</th>\n",
1376
+ " <th>Training Loss</th>\n",
1377
+ " <th>Validation Loss</th>\n",
1378
+ " <th>Matthews Correlation</th>\n",
1379
+ " </tr>\n",
1380
+ " </thead>\n",
1381
+ " <tbody>\n",
1382
+ " <tr>\n",
1383
+ " <td>1</td>\n",
1384
+ " <td>0.503800</td>\n",
1385
+ " <td>0.527398</td>\n",
1386
+ " <td>0.379181</td>\n",
1387
+ " </tr>\n",
1388
+ " </tbody>\n",
1389
+ "</table><p>"
1390
+ ],
1391
+ "text/plain": [
1392
+ "<IPython.core.display.HTML object>"
1393
+ ]
1394
+ },
1395
+ "metadata": {},
1396
+ "output_type": "display_data"
1397
+ },
1398
+ {
1399
+ "name": "stderr",
1400
+ "output_type": "stream",
1401
+ "text": [
1402
+ "[I 2024-03-27 11:14:40,919] Trial 7 finished with value: 0.37918052306046424 and parameters: {'learning_rate': 1.0727131909090178e-05, 'num_train_epochs': 1, 'seed': 37, 'per_device_train_batch_size': 8}. Best is trial 3 with value: 0.5025517897100551.\n",
1403
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1404
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1405
+ " warnings.warn(\n",
1406
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1407
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1408
+ ]
1409
+ },
1410
+ {
1411
+ "data": {
1412
+ "text/html": [
1413
+ "\n",
1414
+ " <div>\n",
1415
+ " \n",
1416
+ " <progress value='2138' max='2138' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1417
+ " [2138/2138 00:52, Epoch 2/2]\n",
1418
+ " </div>\n",
1419
+ " <table border=\"1\" class=\"dataframe\">\n",
1420
+ " <thead>\n",
1421
+ " <tr style=\"text-align: left;\">\n",
1422
+ " <th>Epoch</th>\n",
1423
+ " <th>Training Loss</th>\n",
1424
+ " <th>Validation Loss</th>\n",
1425
+ " <th>Matthews Correlation</th>\n",
1426
+ " </tr>\n",
1427
+ " </thead>\n",
1428
+ " <tbody>\n",
1429
+ " <tr>\n",
1430
+ " <td>1</td>\n",
1431
+ " <td>0.528000</td>\n",
1432
+ " <td>0.511369</td>\n",
1433
+ " <td>0.389045</td>\n",
1434
+ " </tr>\n",
1435
+ " <tr>\n",
1436
+ " <td>2</td>\n",
1437
+ " <td>0.357900</td>\n",
1438
+ " <td>0.638603</td>\n",
1439
+ " <td>0.463981</td>\n",
1440
+ " </tr>\n",
1441
+ " </tbody>\n",
1442
+ "</table><p>"
1443
+ ],
1444
+ "text/plain": [
1445
+ "<IPython.core.display.HTML object>"
1446
+ ]
1447
+ },
1448
+ "metadata": {},
1449
+ "output_type": "display_data"
1450
+ },
1451
+ {
1452
+ "name": "stderr",
1453
+ "output_type": "stream",
1454
+ "text": [
1455
+ "[I 2024-03-27 11:15:33,685] Trial 8 finished with value: 0.46398061315082145 and parameters: {'learning_rate': 4.810569035434538e-05, 'num_train_epochs': 2, 'seed': 11, 'per_device_train_batch_size': 8}. Best is trial 3 with value: 0.5025517897100551.\n",
1456
+ "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1457
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1458
+ " warnings.warn(\n",
1459
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1460
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1461
+ ]
1462
+ },
1463
+ {
1464
+ "data": {
1465
+ "text/html": [
1466
+ "\n",
1467
+ " <div>\n",
1468
+ " \n",
1469
+ " <progress value='268' max='804' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1470
+ " [268/804 00:09 < 00:20, 26.67 it/s, Epoch 1/3]\n",
1471
+ " </div>\n",
1472
+ " <table border=\"1\" class=\"dataframe\">\n",
1473
+ " <thead>\n",
1474
+ " <tr style=\"text-align: left;\">\n",
1475
+ " <th>Epoch</th>\n",
1476
+ " <th>Training Loss</th>\n",
1477
+ " <th>Validation Loss</th>\n",
1478
+ " <th>Matthews Correlation</th>\n",
1479
+ " </tr>\n",
1480
+ " </thead>\n",
1481
+ " <tbody>\n",
1482
+ " <tr>\n",
1483
+ " <td>1</td>\n",
1484
+ " <td>No log</td>\n",
1485
+ " <td>0.571560</td>\n",
1486
+ " <td>0.046356</td>\n",
1487
+ " </tr>\n",
1488
+ " </tbody>\n",
1489
+ "</table><p>"
1490
+ ],
1491
+ "text/plain": [
1492
+ "<IPython.core.display.HTML object>"
1493
+ ]
1494
+ },
1495
+ "metadata": {},
1496
+ "output_type": "display_data"
1497
+ },
1498
+ {
1499
+ "name": "stderr",
1500
+ "output_type": "stream",
1501
+ "text": [
1502
+ "[I 2024-03-27 11:15:44,118] Trial 9 pruned. \n"
1503
+ ]
1504
+ }
1505
+ ],
1506
+ "source": [
1507
+ "best_run = trainer.hyperparameter_search(n_trials=10, direction=\"maximize\")"
1508
+ ]
1509
+ },
1510
+ {
1511
+ "cell_type": "code",
1512
+ "execution_count": 27,
1513
+ "id": "ce0ebef8-3a96-4401-a62b-1771b2a68b24",
1514
+ "metadata": {},
1515
+ "outputs": [
1516
+ {
1517
+ "data": {
1518
+ "text/plain": [
1519
+ "BestRun(run_id='3', objective=0.5025517897100551, hyperparameters={'learning_rate': 9.440074279431108e-05, 'num_train_epochs': 2, 'seed': 17, 'per_device_train_batch_size': 32}, run_summary=None)"
1520
+ ]
1521
+ },
1522
+ "execution_count": 27,
1523
+ "metadata": {},
1524
+ "output_type": "execute_result"
1525
+ }
1526
+ ],
1527
+ "source": [
1528
+ "best_run"
1529
+ ]
1530
+ },
1531
+ {
1532
+ "cell_type": "code",
1533
+ "execution_count": 28,
1534
+ "id": "efba4c29-56d3-459f-836e-ead6ec4c179f",
1535
+ "metadata": {},
1536
+ "outputs": [
1537
+ {
1538
+ "name": "stderr",
1539
+ "output_type": "stream",
1540
+ "text": [
1541
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
1542
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1543
+ ]
1544
+ },
1545
+ {
1546
+ "data": {
1547
+ "text/html": [
1548
+ "\n",
1549
+ " <div>\n",
1550
+ " \n",
1551
+ " <progress value='536' max='536' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1552
+ " [536/536 00:21, Epoch 2/2]\n",
1553
+ " </div>\n",
1554
+ " <table border=\"1\" class=\"dataframe\">\n",
1555
+ " <thead>\n",
1556
+ " <tr style=\"text-align: left;\">\n",
1557
+ " <th>Epoch</th>\n",
1558
+ " <th>Training Loss</th>\n",
1559
+ " <th>Validation Loss</th>\n",
1560
+ " <th>Matthews Correlation</th>\n",
1561
+ " </tr>\n",
1562
+ " </thead>\n",
1563
+ " <tbody>\n",
1564
+ " <tr>\n",
1565
+ " <td>1</td>\n",
1566
+ " <td>No log</td>\n",
1567
+ " <td>0.479286</td>\n",
1568
+ " <td>0.436850</td>\n",
1569
+ " </tr>\n",
1570
+ " <tr>\n",
1571
+ " <td>2</td>\n",
1572
+ " <td>0.414800</td>\n",
1573
+ " <td>0.520329</td>\n",
1574
+ " <td>0.502552</td>\n",
1575
+ " </tr>\n",
1576
+ " </tbody>\n",
1577
+ "</table><p>"
1578
+ ],
1579
+ "text/plain": [
1580
+ "<IPython.core.display.HTML object>"
1581
+ ]
1582
+ },
1583
+ "metadata": {},
1584
+ "output_type": "display_data"
1585
+ },
1586
+ {
1587
+ "data": {
1588
+ "text/plain": [
1589
+ "TrainOutput(global_step=536, training_loss=0.40565217964684785, metrics={'train_runtime': 21.0572, 'train_samples_per_second': 812.168, 'train_steps_per_second': 25.454, 'total_flos': 153655196855484.0, 'train_loss': 0.40565217964684785, 'epoch': 2.0})"
1590
+ ]
1591
+ },
1592
+ "execution_count": 28,
1593
+ "metadata": {},
1594
+ "output_type": "execute_result"
1595
+ }
1596
+ ],
1597
+ "source": [
1598
+ "for n,v in best_run.hyperparameters.items():\n",
1599
+ " setattr(trainer.args, n, v)\n",
1600
+ "\n",
1601
+ "trainer.train()"
1602
+ ]
1603
+ },
1604
+ {
1605
+ "cell_type": "code",
1606
+ "execution_count": null,
1607
+ "id": "06baa2a0-6d79-4e2e-ad8e-d67ec1ed8c57",
1608
+ "metadata": {},
1609
+ "outputs": [],
1610
+ "source": []
1611
+ }
1612
+ ],
1613
+ "metadata": {
1614
+ "kernelspec": {
1615
+ "display_name": "Python 3 (ipykernel)",
1616
+ "language": "python",
1617
+ "name": "python3"
1618
+ },
1619
+ "language_info": {
1620
+ "codemirror_mode": {
1621
+ "name": "ipython",
1622
+ "version": 3
1623
+ },
1624
+ "file_extension": ".py",
1625
+ "mimetype": "text/x-python",
1626
+ "name": "python",
1627
+ "nbconvert_exporter": "python",
1628
+ "pygments_lexer": "ipython3",
1629
+ "version": "3.10.12"
1630
+ }
1631
+ },
1632
+ "nbformat": 4,
1633
+ "nbformat_minor": 5
1634
+ }