yingbei committed on
Commit
91fc1b6
1 Parent(s): 2c4bdd4

Upload fine_tuning_tutorial_jax.ipynb

Files changed (1)
  1. fine_tuning_tutorial_jax.ipynb +1421 -0
fine_tuning_tutorial_jax.ipynb ADDED
@@ -0,0 +1,1421 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "OiBSu3YkEcoX"
7
+ },
8
+ "source": [
9
+ "Copyright 2024 DeepMind Technologies Limited.\n",
10
+ "\n",
11
+ "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
12
+ "\n",
13
+ "http://www.apache.org/licenses/LICENSE-2.0\n",
14
+ "\n",
15
+ "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {
21
+ "id": "Y5OeTiryEcoX"
22
+ },
23
+ "source": [
24
+ "# Fine-tuning the 2B Griffin model with Flax\n",
25
+ "\n",
26
+ "In this tutorial you will learn how to fine-tune the 2B Griffin model for a simple translation task."
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "metadata": {
32
+ "id": "5m81VQOqEcoX"
33
+ },
34
+ "source": [
35
+ "## Setup"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 1,
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Cloning into 'recurrentgemma'...\n",
48
+ "remote: Enumerating objects: 52, done.\u001b[K\n",
49
+ "remote: Counting objects: 100% (49/49), done.\u001b[K\n",
50
+ "remote: Compressing objects: 100% (47/47), done.\u001b[K\n",
51
+ "remote: Total 52 (delta 16), reused 5 (delta 2), pack-reused 3\u001b[K\n",
52
+ "Receiving objects: 100% (52/52), 74.57 KiB | 1.01 MiB/s, done.\n",
53
+ "Resolving deltas: 100% (16/16), done.\n"
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "!git clone https://github.com/google-deepmind/recurrentgemma.git"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 7,
64
+ "metadata": {
65
+ "cellView": "form",
66
+ "id": "XpSw-_4EEcoY"
67
+ },
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "\u001b[33mDEPRECATION: git+https://github.com/google-deepmind/recurrentgemma.git#egg=recurrentgemma[jax] contains an egg fragment with a non-PEP 508 name pip 25.0 will enforce this behaviour change. A possible replacement is to use the req @ url syntax, and remove the egg fragment. Discussion can be found at https://github.com/pypa/pip/issues/11617\u001b[0m\u001b[33m\n",
74
+ "\u001b[0mCollecting recurrentgemma[jax]\n",
75
+ " Cloning https://github.com/google-deepmind/recurrentgemma.git to /private/var/folders/jx/gld2clwj7sd_q8hd2m6hztcr0000gn/T/pip-install-2c9hrit5/recurrentgemma_54f0084d6e164dc38004db09c24dfacb\n",
76
+ " Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/recurrentgemma.git /private/var/folders/jx/gld2clwj7sd_q8hd2m6hztcr0000gn/T/pip-install-2c9hrit5/recurrentgemma_54f0084d6e164dc38004db09c24dfacb\n",
77
+ " Resolved https://github.com/google-deepmind/recurrentgemma.git to commit 0f5ca57442f17c7309c70b0228fd8e5505cbdaa1\n",
78
+ " Installing build dependencies ... \u001b[?25ldone\n",
79
+ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
80
+ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
81
+ "\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.21 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from recurrentgemma[jax]) (1.24.4)\n",
82
+ "Requirement already satisfied: einops<0.8.0,>=0.7.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from recurrentgemma[jax]) (0.7.0)\n",
83
+ "Collecting jaxtyping<0.3.0,>=0.2.28\n",
84
+ " Downloading jaxtyping-0.2.28-py3-none-any.whl (40 kB)\n",
85
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.7/40.7 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
86
+ "\u001b[?25hCollecting absl-py<1.5.0,>=1.4.0\n",
87
+ " Downloading absl_py-1.4.0-py3-none-any.whl (126 kB)\n",
88
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.5/126.5 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
89
+ "\u001b[?25hCollecting sentencepiece<0.3.0,>=0.2.0\n",
90
+ " Downloading sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl (1.2 MB)\n",
91
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m27.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n",
92
+ "\u001b[?25hCollecting orbax-checkpoint==0.5.7\n",
93
+ " Downloading orbax_checkpoint-0.5.7-py3-none-any.whl (159 kB)\n",
94
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m159.2/159.2 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
95
+ "\u001b[?25hCollecting jax<0.5.0,>=0.4.23\n",
96
+ " Downloading jax-0.4.26-py3-none-any.whl (1.9 MB)\n",
97
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.9/1.9 MB\u001b[0m \u001b[31m31.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
98
+ "\u001b[?25hCollecting flax<0.9.0,>=0.8.2\n",
99
+ " Downloading flax-0.8.2-py3-none-any.whl (686 kB)\n",
100
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m686.8/686.8 kB\u001b[0m \u001b[31m43.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
101
+ "\u001b[?25hCollecting etils[epath,epy]\n",
102
+ " Downloading etils-1.7.0-py3-none-any.whl (152 kB)\n",
103
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m152.4/152.4 kB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
104
+ "\u001b[?25hRequirement already satisfied: typing_extensions in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from orbax-checkpoint==0.5.7->recurrentgemma[jax]) (4.9.0)\n",
105
+ "Requirement already satisfied: pyyaml in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from orbax-checkpoint==0.5.7->recurrentgemma[jax]) (6.0.1)\n",
106
+ "Collecting tensorstore>=0.1.51\n",
107
+ " Downloading tensorstore-0.1.56-cp310-cp310-macosx_11_0_arm64.whl (13.0 MB)\n",
108
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.0/13.0 MB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
109
+ "\u001b[?25hCollecting msgpack\n",
110
+ " Downloading msgpack-1.0.8-cp310-cp310-macosx_11_0_arm64.whl (84 kB)\n",
111
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.9/84.9 kB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
112
+ "\u001b[?25hCollecting jaxlib\n",
113
+ " Downloading jaxlib-0.4.26-cp310-cp310-macosx_11_0_arm64.whl (66.7 MB)\n",
114
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.7/66.7 MB\u001b[0m \u001b[31m32.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
115
+ "\u001b[?25hRequirement already satisfied: nest_asyncio in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from orbax-checkpoint==0.5.7->recurrentgemma[jax]) (1.6.0)\n",
116
+ "Requirement already satisfied: protobuf in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from orbax-checkpoint==0.5.7->recurrentgemma[jax]) (4.25.2)\n",
117
+ "Collecting optax\n",
118
+ " Downloading optax-0.2.2-py3-none-any.whl (223 kB)\n",
119
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m223.7/223.7 kB\u001b[0m \u001b[31m29.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
120
+ "\u001b[?25hRequirement already satisfied: rich>=11.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from flax<0.9.0,>=0.8.2->recurrentgemma[jax]) (13.7.1)\n",
121
+ "Requirement already satisfied: scipy>=1.9 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from jax<0.5.0,>=0.4.23->recurrentgemma[jax]) (1.12.0)\n",
122
+ "Collecting ml-dtypes>=0.2.0\n",
123
+ " Downloading ml_dtypes-0.4.0-cp310-cp310-macosx_10_9_universal2.whl (390 kB)\n",
124
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m390.9/390.9 kB\u001b[0m \u001b[31m29.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
125
+ "\u001b[?25hCollecting opt-einsum\n",
126
+ " Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n",
127
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.5/65.5 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
128
+ "\u001b[?25hCollecting typeguard==2.13.3\n",
129
+ " Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)\n",
130
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from rich>=11.1->flax<0.9.0,>=0.8.2->recurrentgemma[jax]) (3.0.0)\n",
131
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from rich>=11.1->flax<0.9.0,>=0.8.2->recurrentgemma[jax]) (2.17.2)\n",
132
+ "Collecting zipp\n",
133
+ " Downloading zipp-3.18.1-py3-none-any.whl (8.2 kB)\n",
134
+ "Requirement already satisfied: fsspec in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint==0.5.7->recurrentgemma[jax]) (2023.10.0)\n",
135
+ "Requirement already satisfied: importlib_resources in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint==0.5.7->recurrentgemma[jax]) (6.1.2)\n",
136
+ "Collecting chex>=0.1.86\n",
137
+ " Downloading chex-0.1.86-py3-none-any.whl (98 kB)\n",
138
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.2/98.2 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
139
+ "\u001b[?25hRequirement already satisfied: toolz>=0.9.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from chex>=0.1.86->optax->flax<0.9.0,>=0.8.2->recurrentgemma[jax]) (0.12.1)\n",
140
+ "Requirement already satisfied: mdurl~=0.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax<0.9.0,>=0.8.2->recurrentgemma[jax]) (0.1.2)\n",
141
+ "Building wheels for collected packages: recurrentgemma\n",
142
+ " Building wheel for recurrentgemma (pyproject.toml) ... \u001b[?25ldone\n",
143
+ "\u001b[?25h Created wheel for recurrentgemma: filename=recurrentgemma-0.1.0-py3-none-any.whl size=73483 sha256=fb0155d9d3fe031716dcb26e7c11b10a02f545879b13d6f5286eb200ec90cd86\n",
144
+ " Stored in directory: /private/var/folders/jx/gld2clwj7sd_q8hd2m6hztcr0000gn/T/pip-ephem-wheel-cache-62nk7qne/wheels/31/37/18/c57f1df6091b661385ab728b959bdfbf2078d9fc7c856899e4\n",
145
+ "Successfully built recurrentgemma\n",
146
+ "Installing collected packages: sentencepiece, zipp, typeguard, opt-einsum, msgpack, ml-dtypes, etils, absl-py, tensorstore, jaxtyping, jaxlib, jax, recurrentgemma, chex, orbax-checkpoint, optax, flax\n",
147
+ " Attempting uninstall: sentencepiece\n",
148
+ " Found existing installation: sentencepiece 0.1.99\n",
149
+ " Uninstalling sentencepiece-0.1.99:\n",
150
+ " Successfully uninstalled sentencepiece-0.1.99\n",
151
+ " Attempting uninstall: absl-py\n",
152
+ " Found existing installation: absl-py 2.1.0\n",
153
+ " Uninstalling absl-py-2.1.0:\n",
154
+ " Successfully uninstalled absl-py-2.1.0\n",
155
+ "Successfully installed absl-py-1.4.0 chex-0.1.86 etils-1.7.0 flax-0.8.2 jax-0.4.26 jaxlib-0.4.26 jaxtyping-0.2.28 ml-dtypes-0.4.0 msgpack-1.0.8 opt-einsum-3.3.0 optax-0.2.2 orbax-checkpoint-0.5.7 recurrentgemma-0.1.0 sentencepiece-0.2.0 tensorstore-0.1.56 typeguard-2.13.3 zipp-3.18.1\n",
156
+ "\n",
157
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
158
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
159
+ "\u001b[31mERROR: Could not find a version that satisfies the requirement tensorflow-cpu (from versions: none)\u001b[0m\u001b[31m\n",
160
+ "\u001b[0m\u001b[31mERROR: No matching distribution found for tensorflow-cpu\u001b[0m\u001b[31m\n",
161
+ "\u001b[0m\n",
162
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
163
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
164
+ "\u001b[31mERROR: Can not perform a '--user' install. User site-packages are not visible in this virtualenv.\u001b[0m\u001b[31m\n",
165
+ "\u001b[0m\n",
166
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
167
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
168
+ "Requirement already satisfied: datasets in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (2.16.1)\n",
169
+ "Requirement already satisfied: pyarrow-hotfix in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (0.6)\n",
170
+ "Requirement already satisfied: xxhash in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
171
+ "Requirement already satisfied: requests>=2.19.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (2.31.0)\n",
172
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (2023.10.0)\n",
173
+ "Requirement already satisfied: numpy>=1.17 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (1.24.4)\n",
174
+ "Requirement already satisfied: pandas in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (2.2.0)\n",
175
+ "Requirement already satisfied: multiprocess in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (0.70.15)\n",
176
+ "Requirement already satisfied: packaging in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (23.2)\n",
177
+ "Requirement already satisfied: pyyaml>=5.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (6.0.1)\n",
178
+ "Requirement already satisfied: huggingface-hub>=0.19.4 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (0.20.3)\n",
179
+ "Requirement already satisfied: filelock in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (3.13.1)\n",
180
+ "Requirement already satisfied: tqdm>=4.62.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (4.66.1)\n",
181
+ "Requirement already satisfied: pyarrow>=8.0.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (15.0.0)\n",
182
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (0.3.7)\n",
183
+ "Requirement already satisfied: aiohttp in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from datasets) (3.9.1)\n",
184
+ "Requirement already satisfied: attrs>=17.3.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n",
185
+ "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
186
+ "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n",
187
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.4)\n",
188
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n",
189
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n",
190
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from huggingface-hub>=0.19.4->datasets) (4.9.0)\n",
191
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
192
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.6)\n",
193
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
194
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2.2.0)\n",
195
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
196
+ "Requirement already satisfied: tzdata>=2022.7 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from pandas->datasets) (2023.4)\n",
197
+ "Requirement already satisfied: pytz>=2020.1 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from pandas->datasets) (2023.4)\n",
198
+ "Requirement already satisfied: six>=1.5 in /Users/tybalex/.pyenv/versions/3.10.12/envs/new3102/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
199
+ "\n",
200
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
201
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "# @title Installation\n",
207
+ "! pip install 'git+https://github.com/google-deepmind/recurrentgemma.git#egg=recurrentgemma[jax]'\n",
208
+ "! pip install tensorflow-cpu # Might require a session restart\n",
209
+ "! pip install --user kaggle\n",
210
+ "! pip install datasets"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": 10,
216
+ "metadata": {
217
+ "id": "yWaP_LPoEcoY"
218
+ },
219
+ "outputs": [
220
+ {
221
+ "ename": "ModuleNotFoundError",
222
+ "evalue": "No module named 'tensorflow'",
223
+ "output_type": "error",
224
+ "traceback": [
225
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
226
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
227
+ "Cell \u001b[0;32mIn[10], line 20\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrecurrentgemma\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m jax \u001b[38;5;28;01mas\u001b[39;00m recurrentgemma\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# We will use tensorflow to handle the dataset\u001b[39;00m\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtf\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow_datasets\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtfds\u001b[39;00m\n",
228
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tensorflow'"
229
+ ]
230
+ }
231
+ ],
232
+ "source": [
233
+ "# @title Python imports\n",
234
+ "import pathlib\n",
235
+ "from typing import Any, Mapping, Iterator\n",
236
+ "import enum\n",
237
+ "import functools\n",
238
+ "\n",
239
+ "# We import JAX and some related packages.\n",
240
+ "import chex\n",
241
+ "import jax\n",
242
+ "import jax.numpy as jnp\n",
243
+ "import optax\n",
244
+ "\n",
245
+ "\n",
246
+ "\n",
247
+ "# Finally, we import Recurrentgemma.\n",
248
+ "import sentencepiece as spm\n",
249
+ "from recurrentgemma import jax as recurrentgemma\n",
250
+ "\n",
251
+ "# We will use tensorflow to handle the dataset\n",
252
+ "import tensorflow as tf\n",
253
+ "import tensorflow_datasets as tfds"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {},
259
+ "source": []
260
+ },
261
+ {
262
+ "cell_type": "markdown",
263
+ "metadata": {
264
+ "id": "iLafhtv3Rg5F"
265
+ },
266
+ "source": [
267
+ "### Downloading the checkpoint\n",
268
+ "\n",
269
+ "To use Griffin's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
270
+ "\n",
271
+ "1. Visit https://www.kaggle.com/ and create an account.\n",
272
+ "2. Go to your account settings, then the 'API' section.\n",
273
+ "3. Click 'Create new token' to download your key.\n",
274
+ "\n",
275
+ "You will also need to acknowledge the Terms and Conditions of the RecrurrentGemma models on https://www.kaggle.com/models/google/recurrentgemma/ in order to be able to download the model weights and the tokenizer.\n",
276
+ "\n",
277
+ "Then run the cell below."
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "markdown",
282
+ "metadata": {
283
+ "id": "jCZSmEVDVv6O"
284
+ },
285
+ "source": [
286
+ "If everything went well, you should see:\n",
287
+ "```\n",
288
+ "Kaggle credentials set.\n",
289
+ "Kaggle credentials successfully validated.\n",
290
+ "```\n",
291
+ "\n",
292
+ "Now select and download the checkpoint you want to try. The 2b model can fit in memory for fine-tuning."
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "markdown",
297
+ "metadata": {
298
+ "id": "DVgmx04E2ztl"
299
+ },
300
+ "source": [
301
+ "Need to visit the kaggle page and agree to their term."
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": 11,
307
+ "metadata": {
308
+ "id": "RoUb7Shg-bex"
309
+ },
310
+ "outputs": [
311
+ {
312
+ "name": "stdout",
313
+ "output_type": "stream",
314
+ "text": [
315
+ "fatal: destination path 'recurrentg-2b-it' already exists and is not an empty directory.\n"
316
+ ]
317
+ },
318
+ {
319
+ "name": "stderr",
320
+ "output_type": "stream",
321
+ "text": [
322
+ "/Users/tybalex/.pyenv/versions/3.10.12/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
323
+ " pid, fd = os.forkpty()\n"
324
+ ]
325
+ }
326
+ ],
327
+ "source": [
328
+ "!git clone https://huggingface.co/yingbei/recurrentg-2b-it\n"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 13,
334
+ "metadata": {
335
+ "id": "1TOdNwcNBhno"
336
+ },
337
+ "outputs": [],
338
+ "source": [
339
+ "VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}\n",
340
+ "weights_dir = pathlib.Path(\"./recurrentg-2b-it\")\n",
341
+ "ckpt_path = weights_dir / VARIANT\n",
342
+ "vocab_path = weights_dir / 'tokenizer.model'"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "metadata": {
348
+ "id": "ejQhgtjbEcoY"
349
+ },
350
+ "source": [
351
+ "## Step 1: prepare the dataset\n",
352
+ "\n"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "metadata": {
359
+ "id": "XeynYJXCEymJ"
360
+ },
361
+ "outputs": [],
362
+ "source": [
363
+ "from datasets import load_dataset\n",
364
+ "code_sharegpt = load_dataset(\"sanjay920/code74k-sharegpt\")"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "metadata": {
371
+ "id": "yDhp3v7DFSUd"
372
+ },
373
+ "outputs": [],
374
+ "source": [
375
+ "code_sharegpt[\"train\"][0][\"conversations\"]"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": null,
381
+ "metadata": {
382
+ "id": "jOMGn19rG5JE"
383
+ },
384
+ "outputs": [],
385
+ "source": [
386
+ "import json\n",
387
+ "chat_prefix = \"<start_of_turn>\"\n",
388
+ "chat_suffix = \"<end_of_turn>\"\n",
389
+ "user_role = \"user\\n\"\n",
390
+ "preprocessed_code_sharegpt_data = []\n",
391
+ "for itor in code_sharegpt[\"train\"]:\n",
392
+ " c = itor[\"conversations\"]\n",
393
+ " c = json.loads(c)\n",
394
+ " assert c[-1][\"from\"] == \"gpt\"\n",
395
+ " assert c[0][\"from\"] == \"human\"\n",
396
+ " assert len(c) == 2\n",
397
+ " input = chat_prefix + user_role + c[0][\"value\"] + chat_suffix\n",
398
+ " output = c[1][\"value\"]\n",
399
+ " preprocessed_code_sharegpt_data.append({\"input\": input, \"output\": output})\n",
400
+ "\n",
401
+ "print(json.dumps(preprocessed_code_sharegpt_data[0], indent=4))\n",
402
+ "print(len(preprocessed_code_sharegpt_data))\n"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "metadata": {
409
+ "id": "oZSVAbmWVD1q"
410
+ },
411
+ "outputs": [],
412
+ "source": [
413
+ "\n",
414
+ "def load_custom_data(data):\n",
415
+ " # convert list of dicts to tfds dataset format\n",
416
+ " def preprocess(item):\n",
417
+ " # Convert your item here, e.g., tokenize text\n",
418
+ " return {\n",
419
+ " 'src': item['input'], # Assume these are already preprocessed\n",
420
+ " 'dst': item['output'],\n",
421
+ " }\n",
422
+ "\n",
423
+ " # Create a Dataset from the list of dictionaries\n",
424
+ " ds = tf.data.Dataset.from_generator(lambda: (preprocess(item) for item in data),\n",
425
+ " output_types={'src': tf.string, 'dst': tf.string})\n",
426
+ "\n",
427
+ " # Further dataset operations (batching, padding, etc.) go here\n",
428
+ " # For example, to batch:\n",
429
+ " # ds = ds.batch(2)\n",
430
+ "\n",
431
+ " return ds"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "markdown",
436
+ "metadata": {
437
+ "id": "NYC42hJgEcoY"
438
+ },
439
+ "source": [
440
+ "### Tokenizer\n",
441
+ "\n",
442
+ "Let's start by loading our vocabulary base tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "execution_count": null,
448
+ "metadata": {
449
+ "cellView": "form",
450
+ "id": "TpyG5YW1EcoY"
451
+ },
452
+ "outputs": [],
453
+ "source": [
454
+ "vocab = spm.SentencePieceProcessor()\n",
455
+ "vocab.Load(str(vocab_path))"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "metadata": {
461
+ "id": "Ab2MSf-qEcoY"
462
+ },
463
+ "source": [
464
+ "Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Griffin 2B model, we need a few adjustments:\n",
465
+ "\n",
466
+ "- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example we could go with a prompt like `Translate this into French: [INPUT_SENTENCE]`.\n",
467
+ "\n",
468
+ "- **Translation Start suffix**: We add a suffix at the end of each prompt tells the model exactly when to begin the translation process. A new line should do the job.\n",
469
+ "\n",
470
+ "- **LM Tokens**: Griffin models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example."
471
+ ]
472
+ },
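+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick illustration (not part of the original tutorial), the next cell frames a single example by hand with the prefix, the suffix and the BOS/EOS tokens, using the raw `vocab` processor loaded above. The `GriffinTokenizer` wrapper defined right after does the same thing behind a TensorFlow-friendly interface."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustration only: frame one example the way the tokenizer wrapper below will.\n",
+ "prefix = 'Translate this into French:\\n'\n",
+ "suffix = '\\n'\n",
+ "example = 'Hello, my name is Morgane.'\n",
+ "ids = [vocab.bos_id()]\n",
+ "ids += vocab.EncodeAsIds(prefix + example + suffix)\n",
+ "ids += [vocab.eos_id()]\n",
+ "print(ids)"
+ ]
+ },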
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {
477
+ "cellView": "form",
478
+ "id": "L9cjK0uxEcoY"
479
+ },
480
+ "outputs": [],
481
+ "source": [
482
+ "class GriffinTokenizer:\n",
483
+ " \"\"\"Custom wrapper around a SentencePieceProcessor for tensorflow.\"\"\"\n",
484
+ "\n",
485
+ " def __init__(self, spm_processor: spm.SentencePieceProcessor):\n",
486
+ " self._spm_processor = spm_processor\n",
487
+ "\n",
488
+ " @property\n",
489
+ " def pad_id(self) -> int:\n",
490
+ " \"\"\"Fast access to the pad id.\"\"\"\n",
491
+ " return self._spm_processor.pad_id()\n",
492
+ "\n",
493
+ " def tokenize(\n",
494
+ " self,\n",
495
+ " example: str | bytes,\n",
496
+ " prefix: str = '',\n",
497
+ " suffix: str = '',\n",
498
+ " add_eos: bool = True,\n",
499
+ " ) -> jax.Array:\n",
500
+ " \"\"\"\n",
501
+ " Tokenization function.\n",
502
+ "\n",
503
+ " Args:\n",
504
+ " example: input string to tokenize.\n",
505
+ " prefix: prefix to add to the input string.\n",
506
+ " suffix: suffix to add to the input string.\n",
507
+ " add_eos: if True, add an end of sentence token at the end of the output\n",
508
+ " sequence.\n",
509
+ " Returns:\n",
510
+ " Tokens corresponding to the input string.\n",
511
+ " \"\"\"\n",
512
+ " int_list = [self._spm_processor.bos_id()]\n",
513
+ " int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n",
514
+ " if add_eos:\n",
515
+ " int_list.append(self._spm_processor.eos_id())\n",
516
+ "\n",
517
+ " return jnp.array(int_list, dtype=jnp.int32)\n",
518
+ "\n",
519
+ " def tokenize_tf_op(\n",
520
+ " self,\n",
521
+ " str_tensor: tf.Tensor,\n",
522
+ " prefix: str = '',\n",
523
+ " suffix: str = '',\n",
524
+ " add_eos: bool = True,\n",
525
+ " ) -> tf.Tensor:\n",
526
+ " \"\"\"Tensforflow operator for the tokenize function.\"\"\"\n",
527
+ " encoded = tf.numpy_function(\n",
528
+ " self.tokenize,\n",
529
+ " [str_tensor, prefix, suffix, add_eos],\n",
530
+ " tf.int32)\n",
531
+ " encoded.set_shape([None])\n",
532
+ " return encoded\n",
533
+ "\n",
534
+ " def to_string(self, tokens: jax.Array) -> str:\n",
535
+ " \"\"\"Convert an array of tokens to a string.\"\"\"\n",
536
+ " return self._spm_processor.EncodeIds(tokens.tolist())"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "markdown",
541
+ "metadata": {
542
+ "id": "6xuCVkurEcoY"
543
+ },
544
+ "source": [
545
+ "Now let's try our custom tokenizer on the MTNT dataset"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": null,
551
+ "metadata": {
552
+ "cellView": "form",
553
+ "id": "xEA-97ioEcoY"
554
+ },
555
+ "outputs": [],
556
+ "source": [
557
+ "def tokenize_source(tokenizer, example: tf.Tensor):\n",
558
+ " return tokenizer.tokenize_tf_op(\n",
559
+ " example,\n",
560
+ " prefix='',\n",
561
+ " suffix='\\n<start_of_turn>model\\n',\n",
562
+ " add_eos=False\n",
563
+ " )\n",
564
+ "def tokenize_destination(tokenizer, example: tf.Tensor):\n",
565
+ " return tokenizer.tokenize_tf_op(example, add_eos=True)\n",
566
+ "\n",
567
+ "tokenizer = GriffinTokenizer(vocab)\n",
568
+ "# ds = tfds.load(\"mtnt/en-fr\",split=\"train\")\n",
569
+ "\n",
570
+ "# ds = ds.take(2)\n",
571
+ "# for d in ds:\n",
572
+ "# print(d)\n",
573
+ "\n",
574
+ "ds = load_custom_data(preprocessed_code_sharegpt_data[:2])\n",
575
+ "print(ds)\n",
576
+ "ds = ds.map(lambda x: {\n",
577
+ " 'input': tokenize_source(tokenizer, x['src']),\n",
578
+ " 'output': tokenize_destination(tokenizer, x['dst'])\n",
579
+ " })\n",
580
+ "ds = ds.as_numpy_iterator()\n",
581
+ "for idx, example in enumerate(ds):\n",
582
+ " print(f'Example {idx}:')\n",
583
+ " for key, val in example.items():\n",
584
+ " print(f'{key}: {val}')\n",
585
+ " print()"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "markdown",
590
+ "metadata": {
591
+ "id": "r-x0aTugEcoY"
592
+ },
593
+ "source": [
594
+ "### Data loader\n",
595
+ "\n",
596
+ "We can now wrap everything a build our data loader."
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "metadata": {
603
+ "cellView": "form",
604
+ "id": "XwFFs2mDEcoY"
605
+ },
606
+ "outputs": [],
607
+ "source": [
608
+ "@chex.dataclass(frozen=True)\n",
609
+ "class TrainingInput:\n",
610
+ " # Input tokens given to the model\n",
611
+ " input_tokens: jax.Array\n",
612
+ "\n",
613
+ " # A mask that determines which tokens contribute to the target loss\n",
614
+ " # calculation.\n",
615
+ " target_mask: jax.Array\n",
616
+ "\n",
617
+ "class DatasetSplit(enum.Enum):\n",
618
+ " TRAIN = 'train'\n",
619
+ " VALIDATION = 'valid'\n",
620
+ "\n",
621
+ "\n",
622
+ "class MyDatasetBuilder:\n",
623
+ " \"\"\"Data loader for the MTNT dataset.\"\"\"\n",
624
+ "\n",
625
+ " N_ITEMS = {DatasetSplit.TRAIN: 2000, DatasetSplit.VALIDATION: 100}\n",
626
+ "\n",
627
+ " BUFFER_SIZE_SHUFFLE = 1000\n",
628
+ " TRANSLATION_PREFIX = ''\n",
629
+ " TRANSLATION_SUFFIX = '\\n<start_of_turn>model\\n'\n",
630
+ "\n",
631
+ " def __init__(self,\n",
632
+ " tokenizer : GriffinTokenizer,\n",
633
+ " max_seq_len: int):\n",
634
+ " \"\"\"Constructor.\n",
635
+ "\n",
636
+ " Args:\n",
637
+ " tokenizer: Gemma tokenizer to use.\n",
638
+ " max_seq_len: size of each sequence in a given batch.\n",
639
+ " \"\"\"\n",
640
+ " self._tokenizer = tokenizer\n",
641
+ " self._base_data = {\n",
642
+ " DatasetSplit.TRAIN: load_custom_data(preprocessed_code_sharegpt_data[:2000]),\n",
643
+ " DatasetSplit.VALIDATION: load_custom_data(preprocessed_code_sharegpt_data[-100:]),\n",
644
+ " }\n",
645
+ " self._max_seq_len = max_seq_len\n",
646
+ "\n",
647
+ " def _tokenize_source(self, example: tf.Tensor):\n",
648
+ " \"\"\"Tokenization function for the source.\"\"\"\n",
649
+ " return self._tokenizer.tokenize_tf_op(\n",
650
+ " example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,\n",
651
+ " add_eos=False\n",
652
+ " )\n",
653
+ "\n",
654
+ " def _tokenize_destination(self, example: tf.Tensor):\n",
655
+ " \"\"\"Tokenization function for the French translation.\"\"\"\n",
656
+ " return self._tokenizer.tokenize_tf_op(example, add_eos=True)\n",
657
+ "\n",
658
+ " def _pad_up_to_max_len(self,\n",
659
+ " input_tensor: tf.Tensor,\n",
660
+ " pad_value: int | bool,\n",
661
+ " ) -> tf.Tensor:\n",
662
+ " \"\"\"Pad the given tensor up to sequence length of a batch.\"\"\"\n",
663
+ " seq_len = tf.shape(input_tensor)[0]\n",
664
+ " to_pad = tf.maximum(self._max_seq_len - seq_len, 0)\n",
665
+ " return tf.pad(\n",
666
+ " input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,\n",
667
+ " )\n",
668
+ "\n",
669
+ " def _to_training_input(\n",
670
+ " self,\n",
671
+ " src_tokens: jax.Array,\n",
672
+ " dst_tokens: jax.Array,\n",
673
+ " ) -> TrainingInput:\n",
674
+ " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n",
675
+ "\n",
676
+ " # The input sequence fed to the model is simply the concatenation of the\n",
677
+ " # source and the destination.\n",
678
+ " tokens = tf.concat([src_tokens, dst_tokens], axis=0)\n",
679
+ "\n",
680
+ " # We want to prevent the model from updating based on the source (input)\n",
681
+ " # tokens. To achieve this, we add a target mask to each input.\n",
682
+ " q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)\n",
683
+ " a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)\n",
684
+ " mask = tf.concat([q_mask, a_mask], axis=0)\n",
685
+ "\n",
686
+ " # If the output tokens sequence is smaller than the target sequence size,\n",
687
+ " # then we pad it with pad tokens.\n",
688
+ " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n",
689
+ "\n",
690
+ " # We don't want to perform the backward on the pad tokens.\n",
691
+ " mask = self._pad_up_to_max_len(mask, False)\n",
692
+ "\n",
693
+ " return TrainingInput(input_tokens=tokens, target_mask=mask)\n",
694
+ "\n",
695
+ "\n",
696
+ " def get_train_dataset(self, batch_size: int, num_epochs: int):\n",
697
+ " \"\"\"Build the training dataset.\"\"\"\n",
698
+ "\n",
699
+ " # Tokenize each sample\n",
700
+ " ds = self._base_data[DatasetSplit.TRAIN].map(\n",
701
+ " lambda x : (self._tokenize_source(x['src']),\n",
702
+ " self._tokenize_destination(x['dst']))\n",
703
+ " )\n",
704
+ " print(ds)\n",
705
+ "\n",
706
+ " # Convert them to training inputs\n",
707
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
708
+ "\n",
709
+ " # Remove the samples which are too long\n",
710
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
711
+ "\n",
712
+ " # Shuffle the dataset\n",
713
+ " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n",
714
+ "\n",
715
+ " # Repeat if necessary\n",
716
+ " ds = ds.repeat(num_epochs)\n",
717
+ "\n",
718
+ " # Build batches\n",
719
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
720
+ " return ds\n",
721
+ "\n",
722
+ " def get_validation_dataset(self, batch_size: int):\n",
723
+ " \"\"\"Build the validation dataset.\"\"\"\n",
724
+ "\n",
725
+ " # Same as the training dataset, but no shuffling and no repetition\n",
726
+ " ds = self._base_data[DatasetSplit.VALIDATION].map(\n",
727
+ " lambda x : (self._tokenize_source(x['src']),\n",
728
+ " self._tokenize_destination(x['dst']))\n",
729
+ " )\n",
730
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
731
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
732
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
733
+ " return ds"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "markdown",
738
+ "metadata": {
739
+ "id": "m-BHqBGBVlei"
740
+ },
741
+ "source": [
742
+ "# backup dataset class"
743
+ ]
744
+ },
745
+ {
746
+ "cell_type": "code",
747
+ "execution_count": null,
748
+ "metadata": {
749
+ "id": "daHyZFztVkkE"
750
+ },
751
+ "outputs": [],
752
+ "source": [
753
+ "class MTNTDatasetBuilder:\n",
754
+ " \"\"\"Data loader for the MTNT dataset.\"\"\"\n",
755
+ "\n",
756
+ " N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}\n",
757
+ "\n",
758
+ " BUFFER_SIZE_SHUFFLE = 10_000\n",
759
+ " TRANSLATION_PREFIX = 'Translate this into French:\\n'\n",
760
+ " TRANSLATION_SUFFIX = '\\n'\n",
761
+ "\n",
762
+ " def __init__(self,\n",
763
+ " tokenizer : GriffinTokenizer,\n",
764
+ " max_seq_len: int):\n",
765
+ " \"\"\"Constructor.\n",
766
+ "\n",
767
+ " Args:\n",
768
+ " tokenizer: Gemma tokenizer to use.\n",
769
+ " max_seq_len: size of each sequence in a given batch.\n",
770
+ " \"\"\"\n",
771
+ " self._tokenizer = tokenizer\n",
772
+ " self._base_data = {\n",
773
+ " DatasetSplit.TRAIN: tfds.load(\"mtnt/en-fr\",split=\"train\"),\n",
774
+ " DatasetSplit.VALIDATION: tfds.load(\"mtnt/en-fr\",split=\"valid\"),\n",
775
+ " }\n",
776
+ " self._max_seq_len = max_seq_len\n",
777
+ "\n",
778
+ " def _tokenize_source(self, example: tf.Tensor):\n",
779
+ " \"\"\"Tokenization function for the source.\"\"\"\n",
780
+ " return self._tokenizer.tokenize_tf_op(\n",
781
+ " example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,\n",
782
+ " add_eos=False\n",
783
+ " )\n",
784
+ "\n",
785
+ " def _tokenize_destination(self, example: tf.Tensor):\n",
786
+ " \"\"\"Tokenization function for the French translation.\"\"\"\n",
787
+ " return self._tokenizer.tokenize_tf_op(example, add_eos=True)\n",
788
+ "\n",
789
+ " def _pad_up_to_max_len(self,\n",
790
+ " input_tensor: tf.Tensor,\n",
791
+ " pad_value: int | bool,\n",
792
+ " ) -> tf.Tensor:\n",
793
+ " \"\"\"Pad the given tensor up to sequence length of a batch.\"\"\"\n",
794
+ " seq_len = tf.shape(input_tensor)[0]\n",
795
+ " to_pad = tf.maximum(self._max_seq_len - seq_len, 0)\n",
796
+ " return tf.pad(\n",
797
+ " input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,\n",
798
+ " )\n",
799
+ "\n",
800
+ " def _to_training_input(\n",
801
+ " self,\n",
802
+ " src_tokens: jax.Array,\n",
803
+ " dst_tokens: jax.Array,\n",
804
+ " ) -> TrainingInput:\n",
805
+ " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n",
806
+ "\n",
807
+ " # The input sequence fed to the model is simply the concatenation of the\n",
808
+ " # source and the destination.\n",
809
+ " tokens = tf.concat([src_tokens, dst_tokens], axis=0)\n",
810
+ "\n",
811
+ " # We want to prevent the model from updating based on the source (input)\n",
812
+ " # tokens. To achieve this, we add a target mask to each input.\n",
813
+ " q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)\n",
814
+ " a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)\n",
815
+ " mask = tf.concat([q_mask, a_mask], axis=0)\n",
816
+ "\n",
817
+ " # If the output tokens sequence is smaller than the target sequence size,\n",
818
+ " # then we pad it with pad tokens.\n",
819
+ " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n",
820
+ "\n",
821
+ " # We don't want to perform the backward on the pad tokens.\n",
822
+ " mask = self._pad_up_to_max_len(mask, False)\n",
823
+ "\n",
824
+ " return TrainingInput(input_tokens=tokens, target_mask=mask)\n",
825
+ "\n",
826
+ "\n",
827
+ " def get_train_dataset(self, batch_size: int, num_epochs: int):\n",
828
+ " \"\"\"Build the training dataset.\"\"\"\n",
829
+ "\n",
830
+ " # Tokenize each sample\n",
831
+ " ds = self._base_data[DatasetSplit.TRAIN].map(\n",
832
+ " lambda x : (self._tokenize_source(x['src']),\n",
833
+ " self._tokenize_destination(x['dst']))\n",
834
+ " )\n",
835
+ "\n",
836
+ " # Convert them to training inputs\n",
837
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
838
+ "\n",
839
+ " # Remove the samples which are too long\n",
840
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
841
+ "\n",
842
+ " # Shuffle the dataset\n",
843
+ " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n",
844
+ "\n",
845
+ " # Repeat if necessary\n",
846
+ " ds = ds.repeat(num_epochs)\n",
847
+ "\n",
848
+ " # Build batches\n",
849
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
850
+ " return ds\n",
851
+ "\n",
852
+ " def get_validation_dataset(self, batch_size: int):\n",
853
+ " \"\"\"Build the validation dataset.\"\"\"\n",
854
+ "\n",
855
+ " # Same as the training dataset, but no shuffling and no repetition\n",
856
+ " ds = self._base_data[DatasetSplit.VALIDATION].map(\n",
857
+ " lambda x : (self._tokenize_source(x['src']),\n",
858
+ " self._tokenize_destination(x['dst']))\n",
859
+ " )\n",
860
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
861
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
862
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
863
+ " return ds"
864
+ ]
865
+ },
866
+ {
867
+ "cell_type": "markdown",
868
+ "metadata": {
869
+ "id": "WsOYxL8XXSqf"
870
+ },
871
+ "source": [
872
+ "# Try"
873
+ ]
874
+ },
875
+ {
876
+ "cell_type": "markdown",
877
+ "metadata": {
878
+ "id": "_Sq9uC15EcoZ"
879
+ },
880
+ "source": [
881
+ "Let's give it a try."
882
+ ]
883
+ },
884
+ {
885
+ "cell_type": "code",
886
+ "execution_count": null,
887
+ "metadata": {
888
+ "cellView": "form",
889
+ "id": "bYeduOaNEcoZ"
890
+ },
891
+ "outputs": [],
892
+ "source": [
893
+ "dataset_builder = MyDatasetBuilder(tokenizer, max_seq_len=4000)\n",
894
+ "ds = dataset_builder.get_train_dataset(3, 1)\n",
895
+ "ds = ds.take(2)\n",
896
+ "ds = ds.as_numpy_iterator()\n",
897
+ "for idx, example in enumerate(ds):\n",
898
+ " print(f'Example {idx}:')\n",
899
+ " for key, val in example.items():\n",
900
+ " print(f'{key}: {val}')\n",
901
+ " print()"
902
+ ]
903
+ },
904
+ {
905
+ "cell_type": "markdown",
906
+ "metadata": {
907
+ "id": "_VsT2o6JEcoZ"
908
+ },
909
+ "source": [
910
+ "## Fine tuning Griffin\n",
911
+ "\n",
912
+ "### Getting started\n",
913
+ "\n",
914
+ "First let's load the model. Use the `griffin_lib.GriffinConfig.from_flax_params_or_variables` function to automatically load the correct configuration from a checkpoint."
915
+ ]
916
+ },
917
+ {
918
+ "cell_type": "code",
919
+ "execution_count": null,
920
+ "metadata": {
921
+ "cellView": "form",
922
+ "id": "VDlfziQVEcoZ"
923
+ },
924
+ "outputs": [],
925
+ "source": [
926
+ "# Load parameters\n",
927
+ "params = recurrentgemma.load_parameters(ckpt_path, \"single_device\")\n",
928
+ "config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)\n",
929
+ "model = recurrentgemma.Griffin(config)"
930
+ ]
931
+ },
932
+ {
933
+ "cell_type": "markdown",
934
+ "metadata": {
935
+ "id": "cGbfx6XVEcoZ"
936
+ },
937
+ "source": [
938
+ "Can our model translate French ? Well let's try it out !"
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "code",
943
+ "execution_count": null,
944
+ "metadata": {
945
+ "cellView": "form",
946
+ "id": "jWr6Sea_EcoZ"
947
+ },
948
+ "outputs": [],
949
+ "source": [
950
+ "sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": null,
956
+ "metadata": {
957
+ "cellView": "form",
958
+ "id": "S6937NTjEcoZ"
959
+ },
960
+ "outputs": [],
961
+ "source": [
962
+ "output = sampler(\n",
963
+ " [\"Develop a Python code snippet that generates an abbreviated version of a given full name.\\nname = 'John Smith'\"],\n",
964
+ " # number of steps performed when generating\n",
965
+ " total_generation_steps=300,\n",
966
+ ")\n",
967
+ "print(output.text[0])"
968
+ ]
969
+ },
970
+ {
971
+ "cell_type": "markdown",
972
+ "metadata": {
973
+ "id": "0Z0CXW4REcoZ"
974
+ },
975
+ "source": [
976
+ "As expected, it didn't work. Let's see if we can get better results by fine-tuning."
977
+ ]
978
+ },
979
+ {
980
+ "cell_type": "markdown",
981
+ "metadata": {
982
+ "id": "gxf6gVGCEcoZ"
983
+ },
984
+ "source": [
985
+ "### Model forward and loss function\n",
986
+ "\n",
987
+ "The `Griffin` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:\n",
988
+ "\n",
989
+ "- `init`: Initializes the model's parameters.\n",
990
+ "\n",
991
+ "- `apply`: Executes the model's `__call__` function using a given set of parameters.\n",
992
+ "\n",
993
+ "Since are working with pre-trained weights, we won't use the `init` function.\n",
994
+ "\n",
995
+ "With it we can now build the `forward_function` which performs the forward pass and loss computation."
996
+ ]
997
+ },
998
+ {
999
+ "cell_type": "code",
1000
+ "execution_count": null,
1001
+ "metadata": {
1002
+ "cellView": "form",
1003
+ "id": "iEcV0XEEEcoZ"
1004
+ },
1005
+ "outputs": [],
1006
+ "source": [
1007
+ "def forward_and_loss_fn(\n",
1008
+ " params,\n",
1009
+ " *,\n",
1010
+ " model: recurrentgemma.Griffin,\n",
1011
+ " input_tokens: jax.Array, # Shape [B, L]\n",
1012
+ " input_mask: jax.Array, # Shape [B, L]\n",
1013
+ " positions: jax.Array, # Shape [B, L]\n",
1014
+ ") -> jax.Array:\n",
1015
+ " \"\"\"Forward pass and loss function.\n",
1016
+ "\n",
1017
+ " Args:\n",
1018
+ " params: model's input parameters.\n",
1019
+ " model: Griffin model to call.\n",
1020
+ " input_tokens: input tokens sequence, shape [B, L].\n",
1021
+ " input_mask: tokens to ignore when computing the loss, shape [B, L].\n",
1022
+ " positions: relative position of each token, shape [B, L].\n",
1023
+ "\n",
1024
+ " Returns:\n",
1025
+ " Softmax cross-entropy loss for the next-token prediction task.\n",
1026
+ " \"\"\"\n",
1027
+ " batch_size = input_tokens.shape[0]\n",
1028
+ " # Foward pass on the input data.\n",
1029
+ " # No attention cache is needed here.\n",
1030
+ " # Exclude the last step as it does not appear in the targets.\n",
1031
+ " logits, _ = model.apply(\n",
1032
+ " {\"params\": params},\n",
1033
+ " tokens=input_tokens[:, :-1],\n",
1034
+ " segment_pos=positions[:, :-1],\n",
1035
+ " cache=None,\n",
1036
+ " )\n",
1037
+ "\n",
1038
+ " # Similarly, the first token cannot be predicteds.\n",
1039
+ " target_tokens = input_tokens[:, 1:]\n",
1040
+ " target_mask = input_mask[:, 1:]\n",
1041
+ "\n",
1042
+ " # Convert the target labels into one-hot encoded vectors.\n",
1043
+ " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n",
1044
+ "\n",
1045
+ " # Don't update on unwanted tokens.\n",
1046
+ " one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]\n",
1047
+ "\n",
1048
+ " # Normalisation factor.\n",
1049
+ " norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)\n",
1050
+ "\n",
1051
+ " # Return the nll loss.\n",
1052
+ " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor"
1053
+ ]
1054
+ },
1055
+ {
1056
+ "cell_type": "markdown",
1057
+ "metadata": {
1058
+ "id": "xbxYMMWLEcoZ"
1059
+ },
1060
+ "source": [
1061
+ "We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly."
1062
+ ]
1063
+ },
1064
+ {
1065
+ "cell_type": "code",
1066
+ "execution_count": null,
1067
+ "metadata": {
1068
+ "cellView": "form",
1069
+ "id": "cPSfp7ZUEcoZ"
1070
+ },
1071
+ "outputs": [],
1072
+ "source": [
1073
+ "Params = Mapping[str, Any]\n",
1074
+ "\n",
1075
+ "def get_positions(example: jax.Array, pad_id : int) -> jax.Array:\n",
1076
+ " \"\"\"Builds the position vector from the given tokens.\"\"\"\n",
1077
+ " pad_mask = example != pad_id\n",
1078
+ " positions = jnp.cumsum(pad_mask, axis=-1)\n",
1079
+ " # Subtract one for all positions from the first valid one as they are\n",
1080
+ " # 0-indexed\n",
1081
+ " positions = positions - (positions >= 1)\n",
1082
+ " return positions\n",
1083
+ "\n",
1084
+ "@functools.partial(\n",
1085
+ " jax.jit,\n",
1086
+ " static_argnames=['model', 'optimizer'],\n",
1087
+ " donate_argnames=['params', 'opt_state'],\n",
1088
+ ")\n",
1089
+ "def train_step(\n",
1090
+ " model: recurrentgemma.Griffin,\n",
1091
+ " params: Params,\n",
1092
+ " optimizer: optax.GradientTransformation,\n",
1093
+ " opt_state: optax.OptState,\n",
1094
+ " pad_id: int,\n",
1095
+ " example: TrainingInput,\n",
1096
+ ") -> tuple[jax.Array, Params, optax.OptState]:\n",
1097
+ " \"\"\"Train step.\n",
1098
+ "\n",
1099
+ " Args:\n",
1100
+ " model: Griffin model.\n",
1101
+ " params: model's input parameters.\n",
1102
+ " optimizer: optax optimizer to use.\n",
1103
+ " opt_state: input optimizer's state.\n",
1104
+ " pad_id: id of the pad token.\n",
1105
+ " example: input batch.\n",
1106
+ "\n",
1107
+ " Returns:\n",
1108
+ " Training loss, updated parameters, updated optimizer state.\n",
1109
+ " \"\"\"\n",
1110
+ "\n",
1111
+ " positions = get_positions(example.input_tokens, pad_id)\n",
1112
+ "\n",
1113
+ " # Forward and backward passes\n",
1114
+ " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(\n",
1115
+ " params,\n",
1116
+ " model=model,\n",
1117
+ " input_tokens=example.input_tokens,\n",
1118
+ " input_mask=example.target_mask,\n",
1119
+ " positions=positions,\n",
1120
+ " )\n",
1121
+ " # Update the parameters\n",
1122
+ " updates, opt_state = optimizer.update(grads, opt_state, params)\n",
1123
+ " params = optax.apply_updates(params, updates)\n",
1124
+ "\n",
1125
+ " return train_loss, params, opt_state"
1126
+ ]
1127
+ },
1128
+ {
1129
+ "cell_type": "markdown",
1130
+ "metadata": {
1131
+ "id": "R2QXp116EcoZ"
1132
+ },
1133
+ "source": [
1134
+ "Similarly, we build a `validation_step` function without backward pass."
1135
+ ]
1136
+ },
1137
+ {
1138
+ "cell_type": "code",
1139
+ "execution_count": null,
1140
+ "metadata": {
1141
+ "cellView": "form",
1142
+ "id": "yU4oR92YEcoa"
1143
+ },
1144
+ "outputs": [],
1145
+ "source": [
1146
+ "@functools.partial(jax.jit, static_argnames=['model'])\n",
1147
+ "def validation_step(\n",
1148
+ " model: recurrentgemma.Griffin,\n",
1149
+ " params: Params,\n",
1150
+ " pad_id: int,\n",
1151
+ " example: TrainingInput,\n",
1152
+ ") -> jax.Array:\n",
1153
+ " return forward_and_loss_fn(\n",
1154
+ " params,\n",
1155
+ " model=model,\n",
1156
+ " input_tokens=example.input_tokens,\n",
1157
+ " input_mask=example.target_mask,\n",
1158
+ " positions=get_positions(example.input_tokens, pad_id),\n",
1159
+ " )"
1160
+ ]
1161
+ },
1162
+ {
1163
+ "cell_type": "markdown",
1164
+ "metadata": {
1165
+ "id": "6g6LFWJbEcoa"
1166
+ },
1167
+ "source": [
1168
+ "And now the training loop itself."
1169
+ ]
1170
+ },
1171
+ {
1172
+ "cell_type": "code",
1173
+ "execution_count": null,
1174
+ "metadata": {
1175
+ "cellView": "form",
1176
+ "id": "xT4bAqNLEcoa"
1177
+ },
1178
+ "outputs": [],
1179
+ "source": [
1180
+ "def train_loop(\n",
1181
+ " model: recurrentgemma.Griffin,\n",
1182
+ " params: Params,\n",
1183
+ " optimizer: optax.GradientTransformation,\n",
1184
+ " train_ds: Iterator[TrainingInput],\n",
1185
+ " validation_ds: Iterator[TrainingInput],\n",
1186
+ " num_steps: int | None = None,\n",
1187
+ " eval_every_n: int = 20,\n",
1188
+ "):\n",
1189
+ " opt_state = jax.jit(optimizer.init)(params)\n",
1190
+ "\n",
1191
+ " step_counter = 0\n",
1192
+ " avg_loss=0\n",
1193
+ "\n",
1194
+ " # A first round of validation loss\n",
1195
+ " n_steps_eval = 0\n",
1196
+ " eval_loss = 0\n",
1197
+ " for val_example in validation_ds.as_numpy_iterator():\n",
1198
+ " eval_loss += validation_step(\n",
1199
+ " model, params, dataset_builder._tokenizer.pad_id, val_example\n",
1200
+ " )\n",
1201
+ " n_steps_eval += 1\n",
1202
+ " print(f\"Start, validation loss: {eval_loss/n_steps_eval}\")\n",
1203
+ "\n",
1204
+ " for train_example in train_ds:\n",
1205
+ " train_loss, params, opt_state = train_step(\n",
1206
+ " model=model,\n",
1207
+ " params=params,\n",
1208
+ " optimizer=optimizer,\n",
1209
+ " opt_state=opt_state,\n",
1210
+ " pad_id=dataset_builder._tokenizer.pad_id,\n",
1211
+ " example=train_example,\n",
1212
+ " )\n",
1213
+ "\n",
1214
+ " step_counter += 1\n",
1215
+ " avg_loss += train_loss\n",
1216
+ " if step_counter % eval_every_n == 0:\n",
1217
+ " eval_loss = 0\n",
1218
+ "\n",
1219
+ " n_steps_eval = 0\n",
1220
+ " val_iterator = validation_ds.as_numpy_iterator()\n",
1221
+ " for val_example in val_iterator:\n",
1222
+ " eval_loss += validation_step(\n",
1223
+ " model,\n",
1224
+ " params,\n",
1225
+ " dataset_builder._tokenizer.pad_id,\n",
1226
+ " val_example,\n",
1227
+ " )\n",
1228
+ " n_steps_eval +=1\n",
1229
+ " avg_loss /= eval_every_n\n",
1230
+ " eval_loss /= n_steps_eval\n",
1231
+ " print(f\"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}\")\n",
1232
+ " avg_loss=0\n",
1233
+ " if num_steps is not None and step_counter > num_steps:\n",
1234
+ " break\n",
1235
+ " return params"
1236
+ ]
1237
+ },
1238
+ {
1239
+ "cell_type": "markdown",
1240
+ "metadata": {
1241
+ "id": "hJAuU6P1dGCl"
1242
+ },
1243
+ "source": [
1244
+ "Here you have to choose an optimizer. For devices with smaller memory (like the T4 GPU) we suggest to use SGD as it has a much lower memory footprint. To achieve best finetuning performance we suggest to try Adam-W. We have provided optimal hyper parameters for each optimizer for the particular task in this notebook for the '2b-it' checkpoint."
1245
+ ]
1246
+ },
1247
+ {
1248
+ "cell_type": "code",
1249
+ "execution_count": null,
1250
+ "metadata": {
1251
+ "id": "oMufclhfc-t4"
1252
+ },
1253
+ "outputs": [],
1254
+ "source": [
1255
+ "def griffin_weight_decay_mask(params_like: optax.Params) -> Any:\n",
1256
+ " # Don't put weight decay on the RGLRU, the embeddings and any biases\n",
1257
+ " def enable_weight_decay(path: list[Any], _: Any) -> bool:\n",
1258
+ " # Parameters in the LRU and embedder\n",
1259
+ " path = [dict_key.key for dict_key in path]\n",
1260
+ " if 'rg_lru' in path or 'embedder' in path:\n",
1261
+ " return False\n",
1262
+ " # All biases and scales\n",
1263
+ " if path[-1] in ('b', 'scale'):\n",
1264
+ " return False\n",
1265
+ " return True\n",
1266
+ "\n",
1267
+ " return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)\n",
1268
+ "\n",
1269
+ "optimizer_choice = \"adamw\" #@param [\"sgd\", \"adamw\"]\n",
1270
+ "\n",
1271
+ "if optimizer_choice == \"sgd\":\n",
1272
+ " optimizer = optax.sgd(learning_rate=1e-3)\n",
1273
+ " num_steps = 300\n",
1274
+ "elif optimizer_choice == \"adamw\":\n",
1275
+ " optimizer = optax.adamw(\n",
1276
+ " learning_rate=1e-4,\n",
1277
+ " b2=0.96,\n",
1278
+ " eps=1e-8,\n",
1279
+ " weight_decay=0.1,\n",
1280
+ " mask=griffin_weight_decay_mask,\n",
1281
+ " )\n",
1282
+ " num_steps = 100\n",
1283
+ " pass\n",
1284
+ "else:\n",
1285
+ " raise ValueError(f\"Unknown optimizer: {optimizer_choice}\")"
1286
+ ]
1287
+ },
1288
+ {
1289
+ "cell_type": "markdown",
1290
+ "metadata": {
1291
+ "id": "3tSwzfRdfJ_W"
1292
+ },
1293
+ "source": [
1294
+ "Finally we prepare the training and validation datasets"
1295
+ ]
1296
+ },
1297
+ {
1298
+ "cell_type": "code",
1299
+ "execution_count": null,
1300
+ "metadata": {
1301
+ "id": "0KFz-9OcfM9-"
1302
+ },
1303
+ "outputs": [],
1304
+ "source": [
1305
+ "# Small seq size so that everything fits in memory\n",
1306
+ "num_epochs = 1 #@param {type: \"integer\"}\n",
1307
+ "batch_size = 1 #@param {type: \"integer\"}\n",
1308
+ "sequence_length = 4000 #@param {type: \"integer\"}\n",
1309
+ "\n",
1310
+ "# Make the dataset builder\n",
1311
+ "tokenizer = GriffinTokenizer(vocab)\n",
1312
+ "dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)\n",
1313
+ "\n",
1314
+ "# Build the training dataset\n",
1315
+ "train_ds = dataset_builder.get_train_dataset(\n",
1316
+ " batch_size=batch_size,\n",
1317
+ " num_epochs=num_epochs,\n",
1318
+ ").as_numpy_iterator()\n",
1319
+ "\n",
1320
+ "# Build the validation dataset, with a limited number of samples for this demo\n",
1321
+ "validation_ds = dataset_builder.get_validation_dataset(\n",
1322
+ " batch_size=batch_size,\n",
1323
+ ").take(50)"
1324
+ ]
1325
+ },
1326
+ {
1327
+ "cell_type": "markdown",
1328
+ "metadata": {
1329
+ "id": "muwkf_ZgEcoa"
1330
+ },
1331
+ "source": [
1332
+ "We can now fine-tune our model on a limited number of steps."
1333
+ ]
1334
+ },
1335
+ {
1336
+ "cell_type": "code",
1337
+ "execution_count": null,
1338
+ "metadata": {
1339
+ "id": "vyuWnFY5wSlW"
1340
+ },
1341
+ "outputs": [],
1342
+ "source": [
1343
+ "trained_params = train_loop(\n",
1344
+ " model=model,\n",
1345
+ " params=params,\n",
1346
+ " optimizer=optimizer,\n",
1347
+ " train_ds=train_ds,\n",
1348
+ " validation_ds=validation_ds,\n",
1349
+ " num_steps=num_steps,\n",
1350
+ ")"
1351
+ ]
1352
+ },
1353
+ {
1354
+ "cell_type": "markdown",
1355
+ "metadata": {
1356
+ "id": "abChlybFEcod"
1357
+ },
1358
+ "source": [
1359
+ "Both the training loss and the validation's are going down. But is it working ?\n",
1360
+ "\n",
1361
+ "Let's try again with our previous example. To ensure our input matches the training format, remember to use the prefix 'Translate this into French:\\n' and a newline character at the end. This signals the model to begin translation."
1362
+ ]
1363
+ },
1364
+ {
1365
+ "cell_type": "code",
1366
+ "execution_count": null,
1367
+ "metadata": {
1368
+ "cellView": "form",
1369
+ "id": "S5F3fk22Ecod"
1370
+ },
1371
+ "outputs": [],
1372
+ "source": [
1373
+ "sampler.params = trained_params\n",
1374
+ "output = sampler(\n",
1375
+ " [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n",
1376
+ " total_generation_steps=30,\n",
1377
+ ")\n",
1378
+ "print(output.text[0])"
1379
+ ]
1380
+ },
1381
+ {
1382
+ "cell_type": "code",
1383
+ "execution_count": null,
1384
+ "metadata": {
1385
+ "id": "FdSF-xoChOPD"
1386
+ },
1387
+ "outputs": [],
1388
+ "source": []
1389
+ }
1390
+ ],
1391
+ "metadata": {
1392
+ "accelerator": "GPU",
1393
+ "colab": {
1394
+ "collapsed_sections": [
1395
+ "iLafhtv3Rg5F",
1396
+ "m-BHqBGBVlei"
1397
+ ],
1398
+ "gpuType": "A100",
1399
+ "private_outputs": true,
1400
+ "provenance": []
1401
+ },
1402
+ "kernelspec": {
1403
+ "display_name": "Python 3",
1404
+ "name": "python3"
1405
+ },
1406
+ "language_info": {
1407
+ "codemirror_mode": {
1408
+ "name": "ipython",
1409
+ "version": 3
1410
+ },
1411
+ "file_extension": ".py",
1412
+ "mimetype": "text/x-python",
1413
+ "name": "python",
1414
+ "nbconvert_exporter": "python",
1415
+ "pygments_lexer": "ipython3",
1416
+ "version": "3.10.12"
1417
+ }
1418
+ },
1419
+ "nbformat": 4,
1420
+ "nbformat_minor": 0
1421
+ }