{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e9ca44ab-68d4-4361-a7fb-1f887f1b06c0", "metadata": { "papermill": { "duration": 20.056463, "end_time": "2023-02-01T13:28:53.560235", "exception": false, "start_time": "2023-02-01T13:28:33.503772", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "!pip install -q transformers datasets" ] }, { "cell_type": "code", "execution_count": 2, "id": "d5482d72-f55e-4b09-befc-a0b71fb0f6b3", "metadata": { "papermill": { "duration": 0.126709, "end_time": "2023-02-01T13:28:53.696755", "exception": false, "start_time": "2023-02-01T13:28:53.570046", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning: Unexpected command-line argument -f found.\n", "Warning: Unexpected command-line argument /root/.local/share/jupyter/runtime/kernel-92e2dce4-3520-4966-a7b3-b12619e1a0d7.json found.\n" ] } ], "source": [ "import valohai\n", "\n", "valohai.prepare(\n", " step='train-model',\n", " image='pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime', \n", " default_parameters={ \n", " 'epochs': 10,\n", " 'model': 'google/mt5-small',\n", " }\n", ")\n", "output_path = valohai.outputs().path('model')" ] }, { "cell_type": "code", "execution_count": 3, "id": "7d8321e3-caf8-4f1b-8f4e-568df5e9608c", "metadata": { "papermill": { "duration": 1.139645, "end_time": "2023-02-01T13:28:54.844272", "exception": false, "start_time": "2023-02-01T13:28:53.704627", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "import torch\n", "\n", "torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(torch_device)" ] }, { "cell_type": "code", "execution_count": 4, "id": "f4484a17-8ba2-45a8-b537-24c44bb5bb7c", "metadata": { "papermill": { "duration": 0.782457, "end_time": "2023-02-01T13:28:55.633345", "exception": false, "start_time": "2023-02-01T13:28:54.850888", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mon Mar 27 07:02:29 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 470.129.06 Driver Version: 470.129.06 CUDA Version: 11.4 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA RTX A6000 On | 00000000:05:00.0 Off | Off |\n", "| 30% 31C P8 15W / 300W | 3MiB / 48685MiB | 0% Default |\n", "| | | N/A |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "! nvidia-smi" ] }, { "cell_type": "code", "execution_count": 5, "id": "73334e06-3bf2-4e94-9870-fe3a487398c3", "metadata": { "papermill": { "duration": 45.306651, "end_time": "2023-02-01T13:29:40.951034", "exception": false, "start_time": "2023-02-01T13:28:55.644383", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n", "Found cached dataset wikisql (/root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "train_data = load_dataset('wikisql', split='train+validation')\n", "test_data = load_dataset('wikisql', split='test')" ] }, { "cell_type": "code", "execution_count": 6, "id": "cf5379de-aeb5-4a1c-8d23-9ad1e56dc445", "metadata": { "papermill": { "duration": 0.038407, "end_time": "2023-02-01T13:29:41.013026", "exception": false, "start_time": "2023-02-01T13:29:40.974619", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def format_dataset(example):\n", " return {'input': 'translate to SQL: ' + example['question'] + ' table ID: ' + ', '.join(str(x) for x in example['table']['header']), 'target': example['sql']['human_readable']}" ] }, { "cell_type": "code", "execution_count": 7, "id": "1ce6feef-eab2-4b7a-86f0-c663e5790c5d", "metadata": { "papermill": { "duration": 17.729786, "end_time": "2023-02-01T13:29:58.768354", "exception": false, "start_time": "2023-02-01T13:29:41.038568", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-1ea43016a8276f85.arrow\n" ] } ], "source": [ "train_data = train_data.map(format_dataset, remove_columns=train_data.column_names)" ] }, { "cell_type": "code", "execution_count": 8, "id": "03862b72-56e4-40ab-aae2-81604f69d608", "metadata": { "papermill": { "duration": 4.566604, "end_time": "2023-02-01T13:30:03.373278", "exception": false, "start_time": "2023-02-01T13:29:58.806674", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d/cache-b9e3da7e258b7aa5.arrow\n" ] } ], "source": [ "test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)" ] }, { "cell_type": "code", "execution_count": 9, "id": "6246e5c3-4d91-4c65-9ee9-bfc366339e97", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.7/site-packages (0.1.97)\n", "Collecting protobuf==3.20.*\n", " Downloading protobuf-3.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n", "\u001b[K |████████████████████████████████| 1.0 MB 4.4 MB/s eta 0:00:01\n", "\u001b[?25hInstalling collected packages: protobuf\n", " Attempting uninstall: protobuf\n", " Found existing installation: protobuf 4.22.1\n", " Uninstalling protobuf-4.22.1:\n", " Successfully uninstalled protobuf-4.22.1\n", "Successfully installed protobuf-3.20.3\n" ] } ], "source": [ "!pip install sentencepiece\n", "!pip install protobuf==3.20.*" ] }, { "cell_type": "code", "execution_count": 10, "id": "f162ac75-aeda-409a-af8c-f70f5a1d7cbd", "metadata": { "papermill": { "duration": 16.204849, "end_time": "2023-02-01T13:30:19.617815", "exception": false, "start_time": "2023-02-01T13:30:03.412966", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/transformers/convert_slow_tokenizer.py:447: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n", " \"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option\"\n", "You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.\n", "Downloading pytorch_model.bin: 100%|██████████| 1.20G/1.20G [00:16<00:00, 72.6MB/s]\n", "Downloading (…)neration_config.json: 100%|██████████| 147/147 [00:00<00:00, 31.2kB/s]\n" ] } ], "source": [ "CKPT = valohai.parameters(\"model\").value\n", "from transformers import AutoTokenizer, T5ForConditionalGeneration\n", "tokenizer = AutoTokenizer.from_pretrained(CKPT)\n", "model = T5ForConditionalGeneration.from_pretrained(CKPT).to(torch_device)" ] }, { "cell_type": "code", "execution_count": 11, "id": "6e2c9c3b-dfd1-4a34-ad77-3c8f69ac4854", "metadata": { "papermill": { "duration": 2.058386, "end_time": "2023-02-01T13:30:21.722091", "exception": false, "start_time": "2023-02-01T13:30:19.663705", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Input Mean: 47.4798, %-Input > 256:0.0, %-Input > 128:0.001, %-Input > 64:0.0684 Output Mean:19.4288, %-Output > 256:0.0, %-Output > 128:0.0002, %-Output > 64:0.0004\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r" ] } ], "source": [ "# map article and summary len to dict as well as if sample is longer than 512 tokens\n", "def map_to_length(x):\n", " x[\"input_len\"] = len(tokenizer(x[\"input\"]).input_ids)\n", " x[\"input_longer_256\"] = int(x[\"input_len\"] > 256)\n", " x[\"input_longer_128\"] = int(x[\"input_len\"] > 128)\n", " x[\"input_longer_64\"] = int(x[\"input_len\"] > 64)\n", " x[\"out_len\"] = len(tokenizer(x[\"target\"]).input_ids)\n", " x[\"out_longer_256\"] = int(x[\"out_len\"] > 256)\n", " x[\"out_longer_128\"] = int(x[\"out_len\"] > 128)\n", " x[\"out_longer_64\"] = int(x[\"out_len\"] > 64)\n", " return x\n", "\n", "sample_size = 10000\n", "data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)\n", "\n", "def compute_and_print_stats(x):\n", " if len(x[\"input_len\"]) == sample_size:\n", " print(\n", " \"Input Mean: {}, %-Input > 256:{}, %-Input > 128:{}, %-Input > 64:{} Output Mean:{}, %-Output > 256:{}, %-Output > 128:{}, %-Output > 64:{}\".format(\n", " sum(x[\"input_len\"]) / sample_size,\n", " sum(x[\"input_longer_256\"]) / sample_size,\n", " sum(x[\"input_longer_128\"]) / sample_size,\n", " sum(x[\"input_longer_64\"]) / sample_size, \n", " sum(x[\"out_len\"]) / sample_size,\n", " sum(x[\"out_longer_256\"]) / sample_size,\n", " sum(x[\"out_longer_128\"]) / sample_size,\n", " sum(x[\"out_longer_64\"]) / sample_size,\n", " )\n", " )\n", "\n", "output = data_stats.map(\n", " compute_and_print_stats, \n", " batched=True,\n", " batch_size=-1,\n", ") " ] }, { "cell_type": "code", "execution_count": 12, "id": "d6b69f36-bd57-46e4-b77e-a0017ffbf64e", "metadata": { "papermill": { "duration": 0.063495, "end_time": "2023-02-01T13:30:21.834853", "exception": false, "start_time": "2023-02-01T13:30:21.771358", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# tokenize the examples\n", "def convert_to_features(example_batch):\n", " input_encodings = tokenizer.batch_encode_plus(example_batch['input'], pad_to_max_length=True, max_length=100, truncation=True)\n", " target_encodings = tokenizer.batch_encode_plus(example_batch['target'], pad_to_max_length=True, max_length=100, truncation=True)\n", "\n", " encodings = {\n", " 'input_ids': input_encodings['input_ids'], \n", " 'attention_mask': input_encodings['attention_mask'],\n", " 'labels': target_encodings['input_ids'],\n", " 'decoder_attention_mask': target_encodings['attention_mask']\n", " }\n", "\n", " return encodings " ] }, { "cell_type": "code", "execution_count": 13, "id": "67b3b61d-e1ae-435e-8f55-46fa219ea3e2", "metadata": { "papermill": { "duration": 23.172287, "end_time": "2023-02-01T13:30:45.056685", "exception": false, "start_time": "2023-02-01T13:30:21.884398", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 0%| | 0/64776 [00:00, ? examples/s]/opt/conda/lib/python3.7/site-packages/transformers/tokenization_utils_base.py:2352: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n", " FutureWarning,\n", " \r" ] } ], "source": [ "train_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names)\n", "test_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names)\n", "\n", "columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']\n", "\n", "train_data.set_format(type='torch', columns=columns)\n", "test_data.set_format(type='torch', columns=columns)" ] }, { "cell_type": "code", "execution_count": 14, "id": "69d37693-8c5a-45c2-a9fd-dfab43ed71fa", "metadata": { "papermill": { "duration": 0.106751, "end_time": "2023-02-01T13:30:45.221681", "exception": false, "start_time": "2023-02-01T13:30:45.114930", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "from transformers import Seq2SeqTrainingArguments" ] }, { "cell_type": "code", "execution_count": 15, "id": "644e81ec-1c23-4a2d-a488-f9354c237815", "metadata": { "papermill": { "duration": 0.069207, "end_time": "2023-02-01T13:30:45.347009", "exception": false, "start_time": "2023-02-01T13:30:45.277802", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# set training arguments - Feel free to adapt it\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=output_path,\n", " per_device_train_batch_size=16,\n", " num_train_epochs=valohai.parameters(\"epochs\").value,\n", " per_device_eval_batch_size=16,\n", " predict_with_generate=True,\n", " evaluation_strategy=\"epoch\",\n", " do_train=True,\n", " do_eval=True,\n", " logging_steps=500,\n", " save_strategy=\"epoch\",\n", " #save_steps=1000,\n", " #eval_steps=1000,\n", " overwrite_output_dir=True,\n", " save_total_limit=1,\n", " load_best_model_at_end=True,\n", " push_to_hub=False\n", " #fp16=True, \n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "46d2344c-df83-4495-b700-71e1308f60f1", "metadata": { "papermill": { "duration": 4.757895, "end_time": "2023-02-01T13:30:50.160794", "exception": false, "start_time": "2023-02-01T13:30:45.402899", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "! pip install -q rouge_score" ] }, { "cell_type": "code", "execution_count": 17, "id": "63ca930c-9cd3-4880-beb6-dd44057069bb", "metadata": { "papermill": { "duration": 1.098239, "end_time": "2023-02-01T13:30:51.318015", "exception": false, "start_time": "2023-02-01T13:30:50.219776", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", " \n" ] } ], "source": [ "from datasets import load_metric\n", "rouge = load_metric(\"rouge\")\n", "\n", "def compute_metrics(pred):\n", " labels_ids = pred.label_ids\n", " pred_ids = pred.predictions\n", "\n", " # all unnecessary tokens are removed\n", " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " labels_ids[labels_ids == -100] = tokenizer.pad_token_id\n", " label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n", "\n", " rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=[\"rouge2\"])[\"rouge2\"].mid\n", "\n", " return {\n", " \"rouge2_precision\": round(rouge_output.precision, 4),\n", " \"rouge2_recall\": round(rouge_output.recall, 4),\n", " \"rouge2_fmeasure\": round(rouge_output.fmeasure, 4),\n", " }" ] }, { "cell_type": "code", "execution_count": 18, "id": "2977e566-8714-4164-b7ad-2706dbd26be8", "metadata": { "papermill": { "duration": 0.074325, "end_time": "2023-02-01T13:30:51.451387", "exception": false, "start_time": "2023-02-01T13:30:51.377062", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# instantiate trainer\n", "trainer = Seq2SeqTrainer(\n", " model=model,\n", " args=training_args,\n", " compute_metrics=compute_metrics,\n", " train_dataset=train_data,\n", " eval_dataset=test_data,\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "id": "8dce01a3-61b2-4cb4-b5d2-319d0e946083", "metadata": { "papermill": { "duration": 227.616733, "end_time": "2023-02-01T13:34:39.125675", "exception": false, "start_time": "2023-02-01T13:30:51.508942", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Rouge2 Precision | \n", "Rouge2 Recall | \n", "Rouge2 Fmeasure | \n", "
---|---|---|---|---|---|
1 | \n", "0.103200 | \n", "0.051379 | \n", "0.901000 | \n", "0.817300 | \n", "0.849700 | \n", "
2 | \n", "0.065800 | \n", "0.038024 | \n", "0.917400 | \n", "0.838200 | \n", "0.869300 | \n", "
3 | \n", "0.054700 | \n", "0.033012 | \n", "0.923000 | \n", "0.844100 | \n", "0.875000 | \n", "
4 | \n", "0.045900 | \n", "0.030169 | \n", "0.928600 | \n", "0.847300 | \n", "0.880000 | \n", "
5 | \n", "0.040100 | \n", "0.028730 | \n", "0.930800 | \n", "0.849800 | \n", "0.882400 | \n", "
6 | \n", "0.039300 | \n", "0.027651 | \n", "0.931800 | \n", "0.850700 | \n", "0.883300 | \n", "
7 | \n", "0.036000 | \n", "0.027332 | \n", "0.932900 | \n", "0.852000 | \n", "0.884600 | \n", "
8 | \n", "0.033500 | \n", "0.026453 | \n", "0.933100 | \n", "0.852300 | \n", "0.884900 | \n", "
9 | \n", "0.032800 | \n", "0.026168 | \n", "0.934200 | \n", "0.853100 | \n", "0.885800 | \n", "
10 | \n", "0.032300 | \n", "0.026122 | \n", "0.934300 | \n", "0.853100 | \n", "0.885900 | \n", "
"
],
"text/plain": [
"