{ "cells": [ { "cell_type": "markdown", "source": [ "## Fine-tuning RoBERTa large for token classification\n", "\n", "Treats fixing commas as a NER problem, where for each token we predict whether a comma should be inserted after it. We assume input data has no commas, which ensures the input distribution is the same for the model, regardless of the types of mistakes users could make. The model would then restore the commas and leave the rest of the text intact." ], "attachments": {}, "metadata": { "datalore": { "node_id": "dyUJpYnHWmCybknBwWLe5Z", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "import torch\n", "torch.cuda.is_available()" ], "execution_count": 1, "outputs": [ { "data": { "text/plain": [ "True" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "dwzLkm8DtZ6gwwOBT2Y7rM", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "9tLzdz03okTFWD3Vy0h4vS" } } } }, { "cell_type": "code", "source": [ "from datasets import load_dataset\n", "from transformers import (\n", " AutoModelForTokenClassification,\n", " AutoTokenizer,\n", " DataCollatorForTokenClassification,\n", " TrainingArguments,\n", " Trainer,\n", ")\n", "from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType\n", "import seqeval\n", "import torch\n", "import numpy as np\n", "import re\n", "import evaluate" ], "execution_count": 2, "outputs": [], "metadata": { "datalore": { "node_id": "kmU9kledu94gR9zgAN1Ga5", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "h8AdYv2Y3tw7Yb4W3IGoxc" } } } }, { "cell_type": "code", "source": [ "model_checkpoint = \"roberta-large\"" ], "execution_count": 3, "outputs": [], "metadata": { "datalore": { "node_id": "t1hpVurdN6Ji6IPIUFlqnB", "type": "CODE", "hide_input_from_viewers": 
false, "hide_output_from_viewers": false, "report_properties": { "rowId": "sGTotT7rOtZWKTPgVm4gX6" } } } }, { "cell_type": "markdown", "source": [ "We will use the wikitext dataset, since it is large, has more diverse texts than, e.g., books, and contains quite a lot of commas." ], "attachments": {}, "metadata": { "datalore": { "node_id": "1L9Lgn5UOmYdIAB9VCLOUx", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "wikitext = load_dataset('wikitext', 'wikitext-103-v1') # TODO we should only load part of it, too big to train on whole anyway" ], "execution_count": 4, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "b49ef90878574db2b0ee3e832b51f0e1" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "EzocdU7ZnKbRTA6Qyb2eui" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "b73ea2c1d22f48a588d04f63ae856ad6" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "5EyBntpe2omeh7zplkVxbz" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "4df35e53a0cc41118cf794e623c9e1ae" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "YPtbQNg3pL68UM9OfASsNu" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "d4f75e8ad8ae4b729303ccbe2283de3c" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "mYXcmC0ka6gJT8jwrSYZGF" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": 
"10f1853410664a0b9cc9ea8693a64807" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "tvLSuVyurIDaBZbUl4aC8p" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "f2d9c2140b3f4e068382f63a376c8cb1" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "SjJATq4Bu7Rd02ryJBysxk" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "f69e84b3b9634448b0141e11236366c3" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "kLgRm1cp86gQOV4nvn0nsn" } } }, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "cmS6SmnZ5bCm9dS4HcW3yj", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "yuWdT9AFjYlOMQTPOZZIUT" } } } }, { "cell_type": "code", "source": [ "wikitext" ], "execution_count": 5, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " test: Dataset({\n", " features: ['text'],\n", " num_rows: 4358\n", " })\n", " train: Dataset({\n", " features: ['text'],\n", " num_rows: 1801350\n", " })\n", " validation: Dataset({\n", " features: ['text'],\n", " num_rows: 3760\n", " })\n", "})" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "gL2NZUdQOETnRVwSrzZMK6", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "gE1CBjdhArw5iApIeceEYR" } } } }, { "cell_type": "markdown", "source": [ "### Preprocessing" ], "attachments": {}, "metadata": { "datalore": { "node_id": "TEOArfY5ox2vtfEeWFofK8", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "label_list = [\n", " \"O\",\n", " \"B-COMMA\",\n", "]\n", 
def map_wikitext(x) -> dict:
    """Strip comma tokens from one WikiText row and tag the word each comma followed.

    The row's ``text`` is split on whitespace. Every piece that contains a
    comma is left out of the output; instead, the tag of the most recently
    kept word is switched to ``B-COMMA``. All other pieces are kept and start
    out tagged ``O``. The check is a substring test, so any piece containing
    a comma is treated this way, not only a bare ``,``.

    A comma that shows up before any word has been kept has nothing to attach
    to: it is discarded and the full row text is printed for inspection.

    Args:
        x: A dataset row holding a string under the ``"text"`` key.

    Returns:
        A dict with ``"tokens"`` (the kept words, in order) and ``"tags"``
        (one id from ``label2id`` per kept word, same length as ``"tokens"``).
    """
    words, tags = [], []
    for piece in x["text"].split():
        if ',' not in piece:
            # Ordinary word: keep it; no comma is known to follow it yet.
            words.append(piece)
            tags.append(label2id["O"])
            continue
        if tags:
            # Comma: record it on the previous kept word rather than as a token.
            tags[-1] = label2id["B-COMMA"]
        else:
            # Leading comma cannot be represented; report the row and drop it.
            print(x["text"])
    return {'tokens': words, 'tags': tags}
Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the \" Nameless \" , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" Raven \" . \\n'}" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "vURwu7beQTkZL5mzHlfAjT", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "i6bHaRjA1inmjze1Y7w5KP" } } } }, { "cell_type": "markdown", "source": [ "Other than mapping, we also filter out empty texts (25% in wikitext) and very long paragraphs. We print texts that start with a comma and drop that initial comma, since we cannot represent it (there is no preceding token to tag) and we assume no sentence should start with a comma." ], "attachments": {}, "metadata": { "datalore": { "node_id": "YjN1hYHVnzLmr0byD0haSW", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "wikitext_mapped = wikitext.filter(lambda x: x[\"text\"] and len(x[\"text\"].split()) < 512).map(map_wikitext)" ], "execution_count": 9, "outputs": [ { "name": "stdout", "text": [ " , \n", "\n", " , the slight increase in comparison loop efficiency does not compensate for the extra iteration . 
Knuth 1998 gives a value of \n" ], "output_type": "stream" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "4570a228ecd043198331a1378d9f9621" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "zkpNXnOddd31KKr2PbuQbu" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "73a77fce14294f9d9d1de72cbdb93bac" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "oMvc26V11XATUcJL3MpaSx" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "85c2b88a530b4b10b836185d0fbc4f38" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "3oLW1dUwUPqA7rzPEXv5I8" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "9024a3a6a42c4bc3b0afdc30557ad09d" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "MRLrkIa0FFrxwmmKmV2Hc2" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "1e08db6867a44093b73a3d78f92612d6" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "3eCfbu6IspjzG6HvvX24Td" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "e26308dff4c3457794b8cca3bb455485" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "JxGyWtrVY3EqxihT5S21Z6" } } }, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "prJX5QZg1VmAX6pN0mnyom", "type": "CODE", "hide_input_from_viewers": false, 
"hide_output_from_viewers": false, "report_properties": { "rowId": "48CMjhbFyYMCFfudw3NWLS" } } } }, { "cell_type": "code", "source": [ "wikitext_mapped[\"train\"][1]" ], "execution_count": 10, "outputs": [ { "data": { "text/plain": [ "{'text': ' Senjō no Valkyria 3 : Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the \" Nameless \" , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" Raven \" . \\n',\n", " 'tokens': ['Senjō',\n", " 'no',\n", " 'Valkyria',\n", " '3',\n", " ':',\n", " '',\n", " 'Chronicles',\n", " '(',\n", " 'Japanese',\n", " ':',\n", " '戦場のヴァルキュリア3',\n", " 'lit',\n", " '.',\n", " 'Valkyria',\n", " 'of',\n", " 'the',\n", " 'Battlefield',\n", " '3',\n", " ')',\n", " 'commonly',\n", " 'referred',\n", " 'to',\n", " 'as',\n", " 'Valkyria',\n", " 'Chronicles',\n", " 'III',\n", " 'outside',\n", " 'Japan',\n", " 'is',\n", " 'a',\n", " 'tactical',\n", " 'role',\n", " '@-@',\n", " 'playing',\n", " 'video',\n", " 'game',\n", " 'developed',\n", " 'by',\n", " 'Sega',\n", " 'and',\n", " 'Media.Vision',\n", " 'for',\n", " 'the',\n", " 'PlayStation',\n", " 'Portable',\n", " '.',\n", " 'Released',\n", " 'in',\n", " 'January',\n", " '2011',\n", " 'in',\n", " 'Japan',\n", " 'it',\n", " 'is',\n", " 'the',\n", " 'third',\n", " 'game',\n", " 'in',\n", " 'the',\n", " 'Valkyria',\n", " 'series',\n", " '.',\n", " 'Employing',\n", " 'the',\n", " 'same',\n", " 'fusion',\n", " 'of',\n", " 'tactical',\n", " 'and',\n", " 'real',\n", " '@-@',\n", " 
'time',\n", " 'gameplay',\n", " 'as',\n", " 'its',\n", " 'predecessors',\n", " 'the',\n", " 'story',\n", " 'runs',\n", " 'parallel',\n", " 'to',\n", " 'the',\n", " 'first',\n", " 'game',\n", " 'and',\n", " 'follows',\n", " 'the',\n", " '\"',\n", " 'Nameless',\n", " '\"',\n", " 'a',\n", " 'penal',\n", " 'military',\n", " 'unit',\n", " 'serving',\n", " 'the',\n", " 'nation',\n", " 'of',\n", " 'Gallia',\n", " 'during',\n", " 'the',\n", " 'Second',\n", " 'Europan',\n", " 'War',\n", " 'who',\n", " 'perform',\n", " 'secret',\n", " 'black',\n", " 'operations',\n", " 'and',\n", " 'are',\n", " 'pitted',\n", " 'against',\n", " 'the',\n", " 'Imperial',\n", " 'unit',\n", " '\"',\n", " '',\n", " 'Raven',\n", " '\"',\n", " '.'],\n", " 'tags': [0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 1,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0,\n", " 0]}" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "DCZXZMrMQzYcH8CUXHrd1E", "type": "CODE", "hide_input_from_viewers": false, 
# TODO only compute for B-COMMA, not overall
def compute_metrics(p):
    """Score token-level predictions with seqeval, skipping ignored positions.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` holds per-label
            scores and is arg-maxed over axis 2 to get one label id per
            position; ``labels`` holds the reference ids, with ``-100`` marking
            positions that must not be scored (special tokens, word
            continuations, padding).

    Returns:
        A dict with ``precision``, ``recall``, ``f1`` and ``accuracy``, taken
        from the ``overall_*`` entries of the seqeval result.
    """
    scores, gold = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions, true_labels = [], []
    for predicted_row, gold_row in zip(predicted_ids, gold):
        kept_predictions, kept_labels = [], []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            if gold_id == -100:
                # Position is masked out of the loss, so keep it out of the metric too.
                continue
            kept_predictions.append(label_list[predicted_id])
            kept_labels.append(label_list[gold_id])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
# Load the tokenizer from the same checkpoint variable that is used to load the
# model, so changing `model_checkpoint` can never pair the model with a
# mismatched vocabulary (the name used to be hard-coded here as 'roberta-large').
# add_prefix_space=True is needed because the inputs are passed pre-split into
# words (is_split_into_words=True); RoBERTa's fast tokenizer requires it in
# that mode.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
tokenizer
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split words and spread word tags over sub-tokens.

    Each word may be split into several sub-tokens by the tokenizer. Only the
    first sub-token of a word receives that word's tag; every other position
    (special tokens, for which ``word_ids`` yields ``None``, and the remaining
    sub-tokens of a word) gets ``-100`` so that it is ignored by the loss.

    Args:
        examples: A batch with ``"tokens"`` (a list of word lists) and
            ``"tags"`` (a parallel list of per-word label ids).

    Returns:
        The tokenizer output for the batch with an added ``"labels"`` entry
        holding one aligned label list per example.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    def align(batch_index, word_tags):
        """Build the per-sub-token label list for one example of the batch."""
        aligned = []
        last_word = None
        for word in tokenized_inputs.word_ids(batch_index=batch_index):
            # A label is emitted only where a new word begins.
            starts_new_word = word is not None and word != last_word
            aligned.append(word_tags[word] if starts_new_word else -100)
            last_word = word
        return aligned

    tokenized_inputs["labels"] = [
        align(batch_index, word_tags)
        for batch_index, word_tags in enumerate(examples["tags"])
    ]
    return tokenized_inputs
"datalore": { "widget_id": "40xVPDkPm35WbYnvA1y8Ts" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "ec21ec40a59449edb8f7d9b8c36c9263" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "bTYqFsVN6Du01ntyrDTYIO" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "40fde6fe7da148ecaed3eb5df360c348" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "JZhIN2xMwaecKDza6S8uZV" } } }, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "Cxvaj6ae2iDlWrNtdzFvtR", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "trpuWhYCfTALVnH0ffydT9" } } } }, { "cell_type": "code", "source": [ "tokenized_wikitext = tokenized_wikitext.remove_columns('text')" ], "execution_count": 16, "outputs": [], "metadata": { "datalore": { "node_id": "ZRFAWku3LxAczcB8WGMM8A", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "7NEFb43JKeUw4DZjeR1Qe8" } } } }, { "cell_type": "code", "source": [ "for input_id, label in zip(tokenized_wikitext[\"train\"][1]['input_ids'], tokenized_wikitext[\"train\"][1]['labels']):\n", " print(tokenizer.convert_ids_to_tokens(input_id), id2label[label])" ], "execution_count": 17, "outputs": [ { "name": "stdout", "text": [ " -100\n", "ĠSen 0\n", "j -100\n", "Åį -100\n", "Ġno 0\n", "ĠV 0\n", "alky -100\n", "ria -100\n", "Ġ3 0\n", "Ġ: 0\n", " 0\n", "ĠChronicles 0\n", "Ġ( 0\n", "ĠJapanese 0\n", "Ġ: 0\n", "Ġæ 1\n", "Ī -100\n", "¦ -100\n", "å -100\n", "ł -100\n", "´ -100\n", "ãģ® -100\n", "ãĥ´ãĤ¡ -100\n", "ãĥ« -100\n", "ãĤŃ -100\n", "ãĥ¥ -100\n", "ãĥª -100\n", "ãĤ¢ -100\n", "3 -100\n", "Ġlit 0\n", "Ġ. 
0\n", "ĠV 0\n", "alky -100\n", "ria -100\n", "Ġof 0\n", "Ġthe 0\n", "ĠBattlefield 0\n", "Ġ3 0\n", "Ġ) 1\n", "Ġcommonly 0\n", "Ġreferred 0\n", "Ġto 0\n", "Ġas 0\n", "ĠV 0\n", "alky -100\n", "ria -100\n", "ĠChronicles 0\n", "ĠIII 0\n", "Ġoutside 0\n", "ĠJapan 1\n", "Ġis 0\n", "Ġa 0\n", "Ġtactical 0\n", "Ġrole 0\n", "Ġ@ 0\n", "- -100\n", "@ -100\n", "Ġplaying 0\n", "Ġvideo 0\n", "Ġgame 0\n", "Ġdeveloped 0\n", "Ġby 0\n", "ĠSega 0\n", "Ġand 0\n", "ĠMedia 0\n", ". -100\n", "Vision -100\n", "Ġfor 0\n", "Ġthe 0\n", "ĠPlayStation 0\n", "ĠPortable 0\n", "Ġ. 0\n", "ĠReleased 0\n", "Ġin 0\n", "ĠJanuary 0\n", "Ġ2011 0\n", "Ġin 0\n", "ĠJapan 1\n", "Ġit 0\n", "Ġis 0\n", "Ġthe 0\n", "Ġthird 0\n", "Ġgame 0\n", "Ġin 0\n", "Ġthe 0\n", "ĠV 0\n", "alky -100\n", "ria -100\n", "Ġseries 0\n", "Ġ. 0\n", "ĠEmploy 0\n", "ing -100\n", "Ġthe 0\n", "Ġsame 0\n", "Ġfusion 0\n", "Ġof 0\n", "Ġtactical 0\n", "Ġand 0\n", "Ġreal 0\n", "Ġ@ 0\n", "- -100\n", "@ -100\n", "Ġtime 0\n", "Ġgameplay 0\n", "Ġas 0\n", "Ġits 0\n", "Ġpredecessors 1\n", "Ġthe 0\n", "Ġstory 0\n", "Ġruns 0\n", "Ġparallel 0\n", "Ġto 0\n", "Ġthe 0\n", "Ġfirst 0\n", "Ġgame 0\n", "Ġand 0\n", "Ġfollows 0\n", "Ġthe 0\n", "Ġ\" 0\n", "ĠNam 0\n", "eless -100\n", "Ġ\" 1\n", "Ġa 0\n", "Ġpenal 0\n", "Ġmilitary 0\n", "Ġunit 0\n", "Ġserving 0\n", "Ġthe 0\n", "Ġnation 0\n", "Ġof 0\n", "ĠGall 0\n", "ia -100\n", "Ġduring 0\n", "Ġthe 0\n", "ĠSecond 0\n", "ĠEuro 0\n", "pan -100\n", "ĠWar 0\n", "Ġwho 0\n", "Ġperform 0\n", "Ġsecret 0\n", "Ġblack 0\n", "Ġoperations 0\n", "Ġand 0\n", "Ġare 0\n", "Ġpitted 0\n", "Ġagainst 0\n", "Ġthe 0\n", "ĠImperial 0\n", "Ġunit 0\n", "Ġ\" 0\n", " 0\n", "ĠRaven 0\n", "Ġ\" 0\n", "Ġ. 
0\n", " -100\n" ], "output_type": "stream" } ], "metadata": { "datalore": { "node_id": "yFJf4KbWqy5NZBDOa5CUJ7", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "PeaXXWTbo8nx301q3J7LYT" } } } }, { "cell_type": "markdown", "source": [ "The collator automatically handles padding the tokens and labels inside batches" ], "attachments": {}, "metadata": { "datalore": { "node_id": "VWf2JoJ24cwXU3AclqyUcM", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)" ], "execution_count": 18, "outputs": [], "metadata": { "datalore": { "node_id": "zswKKFPLgf53urFtHqKcTk", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "Tx4nI6q73aLba0qtxPM3HO" } } } }, { "cell_type": "markdown", "source": [ "### Training" ], "attachments": {}, "metadata": { "datalore": { "node_id": "JOUeiDhDqOWHLSWnNxMUb8", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "\n", "model = AutoModelForTokenClassification.from_pretrained(\n", " model_checkpoint, num_labels=len(label_list), id2label=id2label, label2id=label2id\n", ")" ], "execution_count": 19, "outputs": [ { "name": "stderr", "text": [ "Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ], "output_type": "stream" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "02c337abd8a14ec1a9b812b83c5996f7" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": 
# LoRA adapter configuration for the token-classification task.
# r is the adapter rank and lora_alpha its scaling factor; bias="all" asks PEFT
# to train the bias terms as well (see the PEFT LoraConfig docs).
peft_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="all"
)

# Wrap the base model with the LoRA adapters and report how much of it stays
# trainable. The recorded run printed 1,848,324 trainable parameters out of
# 355,887,108 (about 0.52%).
# NOTE(review): this rebinds `model`, so re-running the cell without reloading
# the base model would wrap an already wrapped model.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Optimizer step size and per-device batch size, shared by training and evaluation.
lr = 1e-3
batch_size = 8

# With gradient_accumulation_steps=4 each optimizer step sees
# 8 * 4 = 32 sequences per device. Training length is bounded by max_steps.
# Evaluation and checkpointing both run every 100 steps, which keeps the two
# schedules aligned for load_best_model_at_end=True; no metric_for_best_model is
# given, so "best" is decided by the Trainer's default (evaluation loss per the
# TrainingArguments docs). At most 3 checkpoints are kept on disk.
training_args = TrainingArguments(
    output_dir="roberta-large-lora-token-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    warmup_steps=200,
    max_steps=20000,
    logging_steps = 10,
    save_steps=100,
    save_total_limit=3,
    weight_decay=0.01,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    load_best_model_at_end=True,
)

# Train on the tokenized WikiText train split and evaluate on the validation
# split; the collator pads each batch and compute_metrics supplies
# precision/recall/F1/accuracy at every evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_wikitext["train"],
    eval_dataset=tokenized_wikitext["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# The recorded run was stopped by hand (KeyboardInterrupt) at step 4605 of 20000.
trainer.train()
\n", " \n", " \n", " [ 4605/20000 5:12:54 < 17:26:33, 0.25 it/s, Epoch 0.13/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossPrecisionRecallF1Accuracy
1000.0824000.0711840.7381820.7532000.7456150.973547
2000.0622000.0515190.8037000.8505250.8264500.981614
3000.0491000.0446370.8217390.8585480.8397400.983133
4000.0460000.0432860.8271630.8556830.8411810.983369
5000.0497000.0434000.8159750.8732570.8436450.983340
6000.0479000.0439470.7902650.9086910.8453510.982887
7000.0425000.0407060.8465080.8438400.8451710.984087
8000.0455000.0409630.8459990.8452720.8456360.984116
9000.0500000.0424530.8522650.8319960.8420090.983929
10000.0508000.0428360.8603580.8032470.8308220.983163
11000.0491000.0432220.8050930.8967530.8484550.983512
12000.0434000.0430330.8724730.8037250.8366890.983851
13000.0439000.0396720.8492820.8476600.8484700.984416
14000.0499000.0418150.8850530.7883480.8339060.983836
15000.0447000.0410550.8498760.8505250.8502000.984573
16000.0451000.0405920.8479700.8539640.8509570.984603
17000.0429000.0404260.8374040.8711560.8539460.984662
18000.0448000.0411550.8077390.8972300.8501360.983718
19000.0506000.0398290.8660080.8290350.8471190.984598
20000.0464000.0390290.8478220.8551100.8514500.984642
21000.0442000.0389470.8460890.8584530.8522260.984677
22000.0399000.0396190.8247040.8793700.8511600.984170
23000.0454000.0403540.8175360.8896850.8520860.984102
24000.0451000.0407090.8106620.8845270.8459850.983423
25000.0437000.0409590.8291690.8780320.8529020.984411
26000.0410000.0394870.8233210.8874880.8542010.984406
27000.0466000.0410660.8199550.8751670.8466620.983684
28000.0474000.0398010.8366340.8716330.8537750.984632
29000.0455000.0387570.8451300.8553010.8501850.984485
30000.0443000.0393740.8236180.8852910.8533420.984338
31000.0429000.0412260.8124020.8909260.8498540.983797
32000.0439000.0385500.8523170.8483290.8503180.984628
33000.0420000.0407010.8224460.8826170.8514700.984151
34000.0420000.0405630.8259610.8739260.8492670.984033
35000.0480000.0396900.8576770.8322830.8447890.984259
36000.0440000.0398750.8193560.8846230.8507390.984023
37000.0433000.0396580.8723380.8177650.8441710.984460
38000.0464000.0391440.8156610.8904490.8514160.984003
39000.0383000.0399640.8456760.8583570.8519690.984647
40000.0507000.0387390.8368610.8647560.8505800.984362
41000.0425000.0388630.8439160.8670490.8553260.984903
42000.0400000.0397410.8106940.8978030.8520280.983949
43000.0408000.0383870.8781700.8068770.8410150.984298
44000.0448000.0385180.8629830.8367720.8496750.984760
45000.0427000.0391140.8753720.8157590.8445150.984539
46000.0400000.0399470.8577520.8454630.8515630.984829
46040.0400000.0394230.8580570.8452720.8516170.984839
46040.0400000.0376300.8471590.8514300.8492890.984868

" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "ONnqviSgLxFGZ6VsodfzN7", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "pMxR5mn4hMefEjvfv5e39e" } } } }, { "cell_type": "markdown", "source": [ "### Saving and evaluating the model" ], "attachments": {}, "metadata": { "datalore": { "node_id": "H8JNhgIpWGBMp6DTrCC4WC", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "trainer.evaluate(tokenized_wikitext[\"test\"])" ], "execution_count": 26, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.037630438804626465,\n", " 'eval_precision': 0.8471585502984171,\n", " 'eval_recall': 0.8514300617230288,\n", " 'eval_f1': 0.8492889351370101,\n", " 'eval_accuracy': 0.9848677451373048}" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "EdtPuNsMTuWUVQ3AqwHeDh", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ], "execution_count": 27, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "aee257c0976d425da20f443a3973cf36" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "WzG1d54t7f6rHW8IyvyquW" } } }, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "B6WEXUADh8XfyBBfHBk309", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "hub_name = \"klasocki/roberta-large-lora-ner-comma-fixer\"" ], "execution_count": 28, "outputs": [], "metadata": { "datalore": { "node_id": "Dyz7jkJX7CngtBG5SfJB47", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": 
"code", "source": [ "model.push_to_hub(hub_name)" ], "execution_count": 29, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "2609923ebe3d473eaebc5126eb653fd6" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "c03DMHM7RodIRbeCxJUgzo" } } }, "output_type": "display_data" }, { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/klasocki/roberta-large-lora-ner-comma-fixer/commit/b6e99b176b6814a75e841edcfaa8fef649feaf31', commit_message='Upload model', commit_description='', oid='b6e99b176b6814a75e841edcfaa8fef649feaf31', pr_url=None, pr_revision=None, pr_num=None)" ] }, "metadata": {}, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "bvJvzxpWWmNcRzwP5CJpkc", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "markdown", "source": [ "### Inference" ], "attachments": {}, "metadata": { "datalore": { "node_id": "XTrG9hb1fhJ5oV8J3Nm26w", "type": "MD", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "peft_model_id = hub_name\n", "config = PeftConfig.from_pretrained(peft_model_id)\n", "inference_model = AutoModelForTokenClassification.from_pretrained(\n", " config.base_model_name_or_path, num_labels=len(label_list), id2label=id2label, label2id=label2id\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n", "model = PeftModel.from_pretrained(inference_model, peft_model_id)" ], "execution_count": 30, "outputs": [ { "name": "stderr", "text": [ "Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ], "output_type": "stream" }, { 
"data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "ea85d240073e4115b74c2c2f76a976de" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "oRXbz3MAjR4qE4rpRi3lNV" } } }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "6bba9f6fa84a49c5b402fcc40b9c3a20" } }, "metadata": { "application/vnd.jupyter.widget-view+json": { "datalore": { "widget_id": "oOD8zOJnD3w3IRrLQzSxrm" } } }, "output_type": "display_data" } ], "metadata": { "datalore": { "node_id": "4Vxwl8BqSlNJ4tt4aQKyu7", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "text = \"This text should have commas here here and there however it does not.\"\n", "inputs = tokenizer(text, return_tensors=\"pt\")" ], "execution_count": 34, "outputs": [], "metadata": { "datalore": { "node_id": "vBfrIMnQntSHs406eTmIvN", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [ "with torch.no_grad():\n", " logits = model(**inputs).logits\n", "\n", "tokens = inputs.tokens()\n", "predictions = torch.argmax(logits, dim=2)\n", "\n", "for token, prediction in zip(tokens, predictions[0].numpy()):\n", " print((token, model.config.id2label[prediction]))" ], "execution_count": 35, "outputs": [ { "name": "stdout", "text": [ "('', 'O')\n", "('This', 'O')\n", "('Ġtext', 'O')\n", "('Ġshould', 'O')\n", "('Ġhave', 'O')\n", "('Ġcomm', 'O')\n", "('as', 'O')\n", "('Ġhere', 'B-COMMA')\n", "('Ġhere', 'O')\n", "('Ġand', 'O')\n", "('Ġthere', 'B-COMMA')\n", "('Ġhowever', 'O')\n", "('Ġit', 'O')\n", "('Ġdoes', 'O')\n", "('Ġnot', 'O')\n", "('.', 'O')\n", "('', 'O')\n" ], "output_type": "stream" } ], "metadata": { "datalore": { "node_id": "jlEMGCI1ZpLhX3dBpF7DrB", "type": "CODE", "hide_input_from_viewers": true, 
"hide_output_from_viewers": true } } }, { "cell_type": "code", "source": [], "execution_count": null, "outputs": [], "metadata": { "datalore": { "node_id": "pWxG1H7AcqZqzJ543WdvMH", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } } ], "metadata": { "widgets": { "application/vnd.jupyter.widget-state+json": { "version_major": 2, "version_minor": 0, "state": { "1dfa3596c41a47239cd42acaffa03ddc": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "c6b4f1a655334da1976a7bb87d9739a2": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "cd790e68b269453e9ab5b9a961a64279": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_1dfa3596c41a47239cd42acaffa03ddc", "style": "IPY_MODEL_c6b4f1a655334da1976a7bb87d9739a2", "value": "Token is valid (permission: write)." } }, "b6f37cc798714f57b51bcd8284fe5cf3": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "5fab119f434047038d02b468858f9565": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "81734d38280d4abdb9acc5295c143cc9": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_b6f37cc798714f57b51bcd8284fe5cf3", "style": "IPY_MODEL_5fab119f434047038d02b468858f9565", "value": "\u001B[1m\u001B[31mCannot authenticate through git-credential as no helper is defined on your machine." 
} }, "22220074fffd4176b02ec067dd3b5ae8": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "c4ab8887bd204069bb783ff6e8ae7723": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "43ba6d98d5a849188b48919aa6494396": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_22220074fffd4176b02ec067dd3b5ae8", "style": "IPY_MODEL_c4ab8887bd204069bb783ff6e8ae7723", "value": "You might have to re-authenticate when pushing to the Hugging Face Hub." } }, "e0a3c0f366ba468e9cb4aded19c802b4": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "2bbcd517c13a47fb88dd1ec65120abd1": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "ce593c77612d4e7395f6a9b880f84460": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_e0a3c0f366ba468e9cb4aded19c802b4", "style": "IPY_MODEL_2bbcd517c13a47fb88dd1ec65120abd1", "value": "Run the following command in your terminal in case you want to set the 'store' credential helper as default." 
} }, "53609e0450164573a5604a184e504349": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "96288508db244ae390d26c1cf739ff73": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "242cefc7ee344b64ae38e8d17d756437": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_53609e0450164573a5604a184e504349", "style": "IPY_MODEL_96288508db244ae390d26c1cf739ff73", "value": "git config --global credential.helper store" } }, "10df0fb54e6c40e7bfd704e051792c5c": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "1b4886b66e8a4041b3782bc1387e8ec8": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "c27034c532f9404fb5319286abdeb0a8": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_10df0fb54e6c40e7bfd704e051792c5c", "style": "IPY_MODEL_1b4886b66e8a4041b3782bc1387e8ec8", "value": "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001B[0m" } }, "7c0fd075376144b790c33685d4c71e31": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "4ed78afe8cfe41d4be4204980b287005": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "32b669c3e94547d984bd0cddfef27b47": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_7c0fd075376144b790c33685d4c71e31", "style": 
"IPY_MODEL_4ed78afe8cfe41d4be4204980b287005", "value": "Token has not been saved to git credential helper." } }, "9386980fa9684653a205701b24f4c3d7": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "44aadab07124477781bcb65fe9e681ba": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "15ddc7ac585c49ae8aae232b176ddc7d": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_9386980fa9684653a205701b24f4c3d7", "style": "IPY_MODEL_44aadab07124477781bcb65fe9e681ba", "value": "Your token has been saved to /home/datalore/.cache/huggingface/token" } }, "41e8c5786b2041b3b34410041bc9f88f": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "f949e7e6496747b396e00d0a23de8b64": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "74e374e308fe451ba997207ca19244f3": { "model_name": "LabelModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_41e8c5786b2041b3b34410041bc9f88f", "style": "IPY_MODEL_f949e7e6496747b396e00d0a23de8b64", "value": "Login successful" } }, "20fc5c16124b4e6a805073c5de94b74a": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": { "align_items": "center", "display": "flex", "flex_flow": "column", "width": "50%" } }, "aee257c0976d425da20f443a3973cf36": { "model_name": "VBoxModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "children": [ "IPY_MODEL_cd790e68b269453e9ab5b9a961a64279", "IPY_MODEL_81734d38280d4abdb9acc5295c143cc9", 
"IPY_MODEL_43ba6d98d5a849188b48919aa6494396", "IPY_MODEL_ce593c77612d4e7395f6a9b880f84460", "IPY_MODEL_242cefc7ee344b64ae38e8d17d756437", "IPY_MODEL_c27034c532f9404fb5319286abdeb0a8", "IPY_MODEL_32b669c3e94547d984bd0cddfef27b47", "IPY_MODEL_15ddc7ac585c49ae8aae232b176ddc7d", "IPY_MODEL_74e374e308fe451ba997207ca19244f3" ], "layout": "IPY_MODEL_20fc5c16124b4e6a805073c5de94b74a" } }, "ba6b997bc4f749368bcf3c7d75e59dd1": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "b53ebb5e12084d9fa25195fc22715ac8": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "9a00dc77ccb844faa7a7f66a2f58b942": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_ba6b997bc4f749368bcf3c7d75e59dd1", "style": "IPY_MODEL_b53ebb5e12084d9fa25195fc22715ac8", "value": "adapter_model.bin: 100%" } }, "eb326383bb0f4da9935c37eb85073a40": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "6c4558d629c34e689df32a021745ead8": { "model_name": "ProgressStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "791c2cbe9ba643139fbc4119efeed65a": { "model_name": "FloatProgressModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "bar_style": "success", "layout": "IPY_MODEL_eb326383bb0f4da9935c37eb85073a40", "max": 7490973, "style": "IPY_MODEL_6c4558d629c34e689df32a021745ead8", "value": 7490973 } }, "3b922c305a7d48bab391317cd657f8d5": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "1440148cee1e4c8486a2abebbc7faa23": { "model_name": "DescriptionStyleModel", "model_module": 
"@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "e6529cb132894cfab125b482198822b8": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_3b922c305a7d48bab391317cd657f8d5", "style": "IPY_MODEL_1440148cee1e4c8486a2abebbc7faa23", "value": " 7.49M/7.49M [00:04<00:00, 2.26MB/s]" } }, "872fe87fd82c40a89442998ff471b881": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "2609923ebe3d473eaebc5126eb653fd6": { "model_name": "HBoxModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "children": [ "IPY_MODEL_9a00dc77ccb844faa7a7f66a2f58b942", "IPY_MODEL_791c2cbe9ba643139fbc4119efeed65a", "IPY_MODEL_e6529cb132894cfab125b482198822b8" ], "layout": "IPY_MODEL_872fe87fd82c40a89442998ff471b881" } }, "335387903ecb43a586d049bed0fbc012": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "1837a2056268420cb2d90e740c606892": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "e1ee514268c34c298c2641d22457816c": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_335387903ecb43a586d049bed0fbc012", "style": "IPY_MODEL_1837a2056268420cb2d90e740c606892", "value": "Downloading (…)/adapter_config.json: 100%" } }, "5c843a2afc054e7ab2531cc530da226a": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "f42fd48e69eb43188397bd5a2cda9a30": { "model_name": "ProgressStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, 
"8cb66793a8d54b1ab3ecc65157bcd6a9": { "model_name": "FloatProgressModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "bar_style": "success", "layout": "IPY_MODEL_5c843a2afc054e7ab2531cc530da226a", "max": 432, "style": "IPY_MODEL_f42fd48e69eb43188397bd5a2cda9a30", "value": 432 } }, "75fdf0dad626469ab2d138711d72cf08": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "f21e9954b48c4ee288f3782755daaf1e": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "2580697cd74a4e5d8024a7a1e7ad114c": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_75fdf0dad626469ab2d138711d72cf08", "style": "IPY_MODEL_f21e9954b48c4ee288f3782755daaf1e", "value": " 432/432 [00:00<00:00, 28.6kB/s]" } }, "d2f6e2c8392e4f18ac06dd3d0de6acf1": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "ea85d240073e4115b74c2c2f76a976de": { "model_name": "HBoxModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "children": [ "IPY_MODEL_e1ee514268c34c298c2641d22457816c", "IPY_MODEL_8cb66793a8d54b1ab3ecc65157bcd6a9", "IPY_MODEL_2580697cd74a4e5d8024a7a1e7ad114c" ], "layout": "IPY_MODEL_d2f6e2c8392e4f18ac06dd3d0de6acf1" } }, "026ed0e099e04df2864a36ca00c8a695": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "80cede5c3d1d44fc80b1e87d60c4d200": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "34ad5085cda04eff80e9b5314be98c41": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_026ed0e099e04df2864a36ca00c8a695", "style": "IPY_MODEL_80cede5c3d1d44fc80b1e87d60c4d200", "value": "Downloading adapter_model.bin: 100%" } }, "1ec286d338ec4457980391951d3cc025": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "5d0e68b8382645bcb9da308ecf025089": { "model_name": "ProgressStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "b0f2185867c94543a4491b0ecb047bc4": { "model_name": "FloatProgressModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "bar_style": "success", "layout": "IPY_MODEL_1ec286d338ec4457980391951d3cc025", "max": 7490973, "style": "IPY_MODEL_5d0e68b8382645bcb9da308ecf025089", "value": 7490973 } }, "7d5e8608d98441328ec9187e23ea34a4": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "bb5b3e305f9d409cb0916be9ec52d921": { "model_name": "DescriptionStyleModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "description_width": "" } }, "e619f8cac6b04b1dbe7075c74f8e37fc": { "model_name": "HTMLModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "layout": "IPY_MODEL_7d5e8608d98441328ec9187e23ea34a4", "style": "IPY_MODEL_bb5b3e305f9d409cb0916be9ec52d921", "value": " 7.49M/7.49M [00:00<00:00, 13.5MB/s]" } }, "c0bc4d0adacc4f75b1341c772811a50a": { "model_name": "LayoutModel", "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "state": {} }, "6bba9f6fa84a49c5b402fcc40b9c3a20": { "model_name": "HBoxModel", "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "state": { "children": [ "IPY_MODEL_34ad5085cda04eff80e9b5314be98c41", "IPY_MODEL_b0f2185867c94543a4491b0ecb047bc4", 
"IPY_MODEL_e619f8cac6b04b1dbe7075c74f8e37fc" ], "layout": "IPY_MODEL_c0bc4d0adacc4f75b1341c772811a50a" } } } } }, "kernelspec": { "display_name": "Python", "language": "python", "name": "python" }, "datalore": { "computation_mode": "JUPYTER", "package_manager": "pip", "base_environment": "default", "packages": [], "report_row_ids": [ "9tLzdz03okTFWD3Vy0h4vS", "h8AdYv2Y3tw7Yb4W3IGoxc", "sGTotT7rOtZWKTPgVm4gX6", "yuWdT9AFjYlOMQTPOZZIUT", "gE1CBjdhArw5iApIeceEYR", "pW8zIwKks0yGy8XVO6sU9T", "u4ni0ncMGCkPMDSQTb2Fjz", "i6bHaRjA1inmjze1Y7w5KP", "48CMjhbFyYMCFfudw3NWLS", "jCn3aXRNlHZ54hwSvCOPY1", "u9kmEPsjJjMXH8QxYxCkzI", "w9WiJH27gOJfWxpv8b8zZU", "KFRHM0usfnGQnFdbZihU7B", "Uegyfkvr1Grti6ZGTnIIEw", "trpuWhYCfTALVnH0ffydT9", "7NEFb43JKeUw4DZjeR1Qe8", "PeaXXWTbo8nx301q3J7LYT", "Tx4nI6q73aLba0qtxPM3HO", "ej3yZyRCtCDWey9uBUDZ8i", "rPZGw8diH0GN6aP7FKP8ca", "9b8k4ntwhjIkmCPW6b5sUf", "Q4UIhvaDzVEpSjVk1zzHDy", "sgKkzWyaxUFUAVeYMkmETq", "pMxR5mn4hMefEjvfv5e39e" ], "version": 3 } }, "nbformat": 4, "nbformat_minor": 4 }