## Fine-tuning RoBERTa large for token classification

This notebook treats comma fixing as a NER problem: for each token, we predict whether a comma should be inserted after it. We assume the input contains no commas; stripping them up front keeps the input distribution the same for the model, regardless of the kinds of comma mistakes users make. The model then restores the commas and leaves the rest of the text intact.
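To make the formulation concrete, here is a minimal sketch of the round trip on ordinary prose, where commas are attached to the end of the preceding word (unlike WikiText, which separates them with spaces). The helpers `text_to_example` and `restore_commas` are hypothetical and only handle trailing commas on whitespace-separated words; they are not part of the pipeline below.

```python
from __future__ import annotations


def text_to_example(text: str) -> tuple[list[str], list[int]]:
    """Strip trailing commas from whitespace-separated words and label each
    word 1 if a comma followed it in the original text, else 0."""
    words, labels = [], []
    for raw in text.split():
        word = raw.rstrip(",")
        had_comma = word != raw
        if word:
            words.append(word)
            labels.append(1 if had_comma else 0)
        elif labels:
            labels[-1] = 1  # a standalone comma tags the previous word
    return words, labels


def restore_commas(words: list[str], labels: list[int]) -> str:
    """Inverse mapping: append a comma after every word labelled 1."""
    return " ".join(w + "," if label == 1 else w for w, label in zip(words, labels))


words, labels = text_to_example("Released in January 2011 in Japan, it is the third game.")
print(words[5], labels[5])  # Japan 1
print(restore_commas(words, labels))
```

The model only ever sees `words`; predicting `labels` and applying `restore_commas` gives back the punctuated text.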

In [1]:
import torch
torch.cuda.is_available()

True

In [2]:
from datasets import load_dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
import numpy as np
# evaluate's "seqeval" metric also needs the seqeval package to be installed
import evaluate

In [3]:
model_checkpoint = "roberta-large"

We use the WikiText-103 dataset, since it is large, covers more diverse texts than, e.g., books, and contains plenty of commas.

In [4]:
wikitext = load_dataset('wikitext', 'wikitext-103-v1') # TODO we should only load part of it, too big to train on whole anyway

In [5]:
wikitext

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

### Preprocessing

In [6]:
label_list = [
    "O",
    "B-COMMA",
]
id2label = {
    0: "O",
    1: "B-COMMA"
}
label2id = {
    "O": 0,
    "B-COMMA": 1
}

WikiText is already whitespace-tokenized, with commas appearing as separate tokens. We use that: we split each text on whitespace, drop the comma tokens, and tag the token preceding each comma with B-COMMA.

In [7]:
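def map_wikitext(x) -> dict:
    tokens = x["text"].split()
    new_tokens, labels = [], []
    for token in tokens:
        # Note: this test also matches WikiText's "@,@" number separator
        if ',' in token:
            if not labels:
                # A comma with no preceding token cannot be represented: print the text and drop the comma
                print(x["text"])
            else:
                # Drop the comma and tag the preceding token instead
                labels[-1] = label2id["B-COMMA"]
        else:
            labels.append(label2id["O"])
            new_tokens.append(token)
    return {'tokens': new_tokens, 'tags': labels}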
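A quick standalone check of the mapping logic on a toy WikiText-style line. `map_wikitext_line` is a hypothetical copy of `map_wikitext` that takes a plain string and omits the `print`. It shows two behaviours: a comma at the very start of a text is dropped, and because the test is `',' in token` rather than `token == ','`, WikiText's `@,@` thousands separator is treated like a comma too, so `1 @,@ 000` becomes `1 000` with a comma tag on `1`. Whether that is desirable depends on whether the model should also restore number separators.

```python
def map_wikitext_line(text: str) -> dict:
    """Standalone copy of the map_wikitext logic (without the print)."""
    new_tokens, labels = [], []
    for token in text.split():
        if "," in token:  # same test as map_wikitext, so "@,@" matches too
            if labels:
                labels[-1] = 1  # B-COMMA on the preceding token
        else:
            labels.append(0)  # O
            new_tokens.append(token)
    return {"tokens": new_tokens, "tags": labels}


example = map_wikitext_line(" Released in Japan , it sold 1 @,@ 000 copies . \n")
print(example["tokens"])
print(example["tags"])
print(map_wikitext_line(" , the slight increase"))  # leading comma is dropped
```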

In [8]:
wikitext["train"][3]

{'text': ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n'}

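Besides mapping, we filter out empty texts (25% of WikiText) and very long paragraphs (512 or more whitespace-separated tokens; subword sequences that are still too long get truncated during tokenization). `map_wikitext` prints texts that start with a comma and drops that initial comma, since there is no preceding token to attach it to and we assume no sentence should start with a comma.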

In [9]:
wikitext_mapped = wikitext.filter(lambda x: x["text"] and len(x["text"].split()) < 512).map(map_wikitext)

 , 

 , the slight increase in comparison loop efficiency does not compensate for the extra iteration . Knuth 1998 gives a value of 



In [10]:
wikitext_mapped["train"][1]

{'text': ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n',
 'tokens': ['Senjō',
  'no',
  'Valkyria',
  '3',
  ':',
  '<unk>',
  'Chronicles',
  '(',
  'Japanese',
  ':',
  '戦場のヴァルキュリア3',
  'lit',
  '.',
  'Valkyria',
  'of',
  'the',
  'Battlefield',
  '3',
  ')',
  'commonly',
  'referred',
  'to',
  'as',
  'Valkyria',
  'Chronicles',
  'III'

In [11]:
seqeval = evaluate.load("seqeval")

In [12]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Logits of shape (batch, seq_len, num_labels) -> predicted label ids
    predictions = np.argmax(predictions, axis=2)

    # Keep only positions with a real label (drop special tokens, non-first subwords and padding)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
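Since every COMMA "entity" is a single `B-COMMA` token, seqeval's entity-level scores should reduce to ordinary precision, recall and F1 on the comma positions, while `overall_accuracy` is plain token accuracy dominated by `O` tokens. That is why accuracy sits near 0.98 even when F1 is around 0.75. The helpers `comma_prf` and `token_accuracy` below are a hypothetical plain-Python restatement of that reduction, not seqeval's implementation.

```python
def comma_prf(true_seqs, pred_seqs):
    """Precision, recall and F1 for B-COMMA, matching (sequence, position) pairs."""
    true_pos = {(i, j) for i, seq in enumerate(true_seqs)
                for j, tag in enumerate(seq) if tag == "B-COMMA"}
    pred_pos = {(i, j) for i, seq in enumerate(pred_seqs)
                for j, tag in enumerate(seq) if tag == "B-COMMA"}
    tp = len(true_pos & pred_pos)
    precision = tp / len(pred_pos) if pred_pos else 0.0
    recall = tp / len(true_pos) if true_pos else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def token_accuracy(true_seqs, pred_seqs):
    """Fraction of individual tags that match, O tags included."""
    pairs = [(t, p) for ts, ps in zip(true_seqs, pred_seqs) for t, p in zip(ts, ps)]
    return sum(t == p for t, p in pairs) / len(pairs)


true = [["O", "B-COMMA", "O", "O"], ["O", "O", "B-COMMA"]]
pred = [["O", "B-COMMA", "B-COMMA", "O"], ["O", "O", "O"]]
all_o = [["O"] * 4, ["O"] * 3]  # a model that never predicts a comma

print(comma_prf(true, pred))        # (0.5, 0.5, 0.5)
print(comma_prf(true, all_o))       # (0.0, 0.0, 0.0)
print(token_accuracy(true, all_o))  # 5 of 7 tags correct despite F1 = 0
```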

In [13]:
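# add_prefix_space=True is required by RoBERTa tokenizers for pre-split input (is_split_into_words=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)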
tokenizer

RobertaTokenizerFast(name_or_path='roberta-large', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)

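We need to map the whitespace-tokenized WikiText words to RoBERTa's subword tokens, carrying the word tags along. Only the first subword of each word keeps the word's tag; special tokens (`<s>`, `</s>`) and the remaining subwords of a word get the label -100. -100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so these positions do not contribute to the loss, and `compute_metrics` skips them as well. Note that RoBERTa uses byte-level BPE rather than WordPiece, which is why, e.g., `Valkyria` is split into `ĠV`, `alky`, `ria` below.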

In [14]:
def tokenize_and_align_labels(examples):
    # The words are already split, so tokenize them as pre-tokenized input
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["tags"]):
        # word_ids maps each subword to the index of its source word (None for special tokens)
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # special token
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])  # first subword of a word
            else:
                label_ids.append(-100)  # later subword of the same word
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
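To see why -100 works as a mask, here is a plain-Python restatement of a mean cross-entropy that skips positions labelled -100, mirroring what `CrossEntropyLoss` does with its default `ignore_index`. The function `masked_cross_entropy` and the toy logits are illustrative assumptions, not PyTorch code; the point is that changing the logits at an ignored position leaves the loss untouched.

```python
from __future__ import annotations

import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def masked_cross_entropy(logits: list[list[float]], labels: list[int]) -> float:
    """Mean cross-entropy over the positions whose label is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # special tokens and non-first subwords are skipped entirely
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_norm - row[label]
        count += 1
    return total / count


# Toy sequence "<s> ĠSen j Åį Ġno </s>": only "ĠSen" and "Ġno" carry real labels
labels = [-100, 0, -100, -100, 0, -100]
logits = [[0.0, 0.0], [2.0, -1.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.5], [0.0, 0.0]]
perturbed = [row[:] for row in logits]
perturbed[2] = [-5.0, 5.0]  # drastically change an ignored position

print(masked_cross_entropy(logits, labels) == masked_cross_entropy(perturbed, labels))  # True
```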

In [15]:
tokenized_wikitext = wikitext_mapped.map(tokenize_and_align_labels, batched=True)

In [16]:
tokenized_wikitext = tokenized_wikitext.remove_columns('text')

In [17]:
for input_id, label in zip(tokenized_wikitext["train"][1]['input_ids'], tokenized_wikitext["train"][1]['labels']):
  print(tokenizer.convert_ids_to_tokens(input_id), id2label[label])

<s> -100
ĠSen 0
j -100
Åį -100
Ġno 0
ĠV 0
alky -100
ria -100
Ġ3 0
Ġ: 0
<unk> 0
ĠChronicles 0
Ġ( 0
ĠJapanese 0
Ġ: 0
Ġæ 1
Ī -100
¦ -100
å -100
ł -100
´ -100
ãģ® -100
ãĥ´ãĤ¡ -100
ãĥ« -100
ãĤŃ -100
ãĥ¥ -100
ãĥª -100
ãĤ¢ -100
3 -100
Ġlit 0
Ġ. 0
ĠV 0
alky -100
ria -100
Ġof 0
Ġthe 0
ĠBattlefield 0
Ġ3 0
Ġ) 1
Ġcommonly 0
Ġreferred 0
Ġto 0
Ġas 0
ĠV 0
alky -100
ria -100
ĠChronicles 0
ĠIII 0
Ġoutside 0
ĠJapan 1
Ġis 0
Ġa 0
Ġtactical 0
Ġrole 0
Ġ@ 0
- -100
@ -100
Ġplaying 0
Ġvideo 0
Ġgame 0
Ġdeveloped 0
Ġby 0
ĠSega 0
Ġand 0
ĠMedia 0
. -100
Vision -100
Ġfor 0
Ġthe 0
ĠPlayStation 0
ĠPortable 0
Ġ. 0
ĠReleased 0
Ġin 0
ĠJanuary 0
Ġ2011 0
Ġin 0
ĠJapan 1
Ġit 0
Ġis 0
Ġthe 0
Ġthird 0
Ġgame 0
Ġin 0
Ġthe 0
ĠV 0
alky -100
ria -100
Ġseries 0
Ġ. 0
ĠEmploy 0
ing -100
Ġthe 0
Ġsame 0
Ġfusion 0
Ġof 0
Ġtactical 0
Ġand 0
Ġreal 0
Ġ@ 0
- -100
@ -100
Ġtime 0
Ġgameplay 0
Ġas 0
Ġits 0
Ġpredecessors 1
Ġthe 0
Ġstory 0
Ġruns 0
Ġparallel 0
Ġto 0
Ġthe 0
Ġfirst 0
Ġgame 0
Ġand 0
Ġfollows 0
Ġthe 0
Ġ" 0
ĠNam 0
eless -100
Ġ" 1
Ġa 0
Ġp

The collator pads the input IDs and labels dynamically to the longest sequence in each batch, using -100 for the padded label positions so they are ignored by the loss.

In [18]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
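Conceptually, the padding step looks like the sketch below. `pad_batch` is a hypothetical simplification: the real collator delegates to `tokenizer.pad` and returns tensors, and the default `pad_token_id=1` assumes RoBERTa's `<pad>` id. The token ids in the example are toy values.

```python
from __future__ import annotations


def pad_batch(features: list[dict], pad_token_id: int = 1, label_pad_id: int = -100) -> dict:
    """Right-pad input_ids, attention_mask and labels to the longest sequence."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)  # padding never counts in the loss
    return batch


batch = pad_batch([
    {"input_ids": [0, 5, 6, 2], "labels": [-100, 0, 1, -100]},
    {"input_ids": [0, 7, 2], "labels": [-100, 0, -100]},
])
print(batch["input_ids"][1], batch["labels"][1])
```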

### Training

In [19]:

model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint, num_labels=len(label_list), id2label=id2label, label2id=label2id
)

Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
peft_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="all"
)

In [21]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 1,848,324 || all params: 355,887,108 || trainable%: 0.519356829301049


In [22]:
lr = 1e-3
batch_size = 8

In [23]:
training_args = TrainingArguments(
    output_dir="roberta-large-lora-token-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,  # effective batch size: 8 * 4 = 32 sequences per device
    warmup_steps=200,
    max_steps=20000,
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    weight_decay=0.01,
    evaluation_strategy="steps",  # renamed eval_strategy in newer transformers releases
    eval_steps=100,
    save_strategy="steps",
    load_best_model_at_end=True,  # requires save_steps to be a multiple of eval_steps
)
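Assuming the Trainer's default `linear` scheduler, the learning rate ramps up linearly to `lr` over the first `warmup_steps=200` steps and then decays linearly to zero at `max_steps=20000`. The function below is a plain-Python restatement of that schedule (the name `linear_warmup_lr` is ours). Evaluating it around step 1,000 shows the learning rate is still roughly 96% of its peak there.

```python
def linear_warmup_lr(step: int, base_lr: float = 1e-3, warmup_steps: int = 200,
                     total_steps: int = 20000) -> float:
    """Linear warmup to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


for step in (0, 100, 200, 1000, 20000):
    print(step, linear_warmup_lr(step))
```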

In [24]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_wikitext["train"],
    eval_dataset=tokenized_wikitext["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


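KeyboardInterrupt

Training was stopped manually after the evaluation at step 1,000 (of the configured `max_steps=20000`). Metrics on the validation split, evaluated every 100 steps:

| Step | Training loss | Validation loss | Precision | Recall | F1 | Accuracy |
|---:|---:|---:|---:|---:|---:|---:|
| 100 | 0.0824 | 0.071184 | 0.738182 | 0.7532 | 0.745615 | 0.973547 |
| 200 | 0.0622 | 0.051519 | 0.8037 | 0.850525 | 0.82645 | 0.981614 |
| 300 | 0.0491 | 0.044637 | 0.821739 | 0.858548 | 0.83974 | 0.983133 |
| 400 | 0.046 | 0.043286 | 0.827163 | 0.855683 | 0.841181 | 0.983369 |
| 500 | 0.0497 | 0.0434 | 0.815975 | 0.873257 | 0.843645 | 0.98334 |
| 600 | 0.0479 | 0.043947 | 0.790265 | 0.908691 | 0.845351 | 0.982887 |
| 700 | 0.0425 | 0.040706 | 0.846508 | 0.84384 | 0.845171 | 0.984087 |
| 800 | 0.0455 | 0.040963 | 0.845999 | 0.845272 | 0.845636 | 0.984116 |
| 900 | 0.05 | 0.042453 | 0.852265 | 0.831996 | 0.842009 | 0.983929 |
| 1000 | 0.0508 | 0.042836 | 0.860358 | 0.803247 | 0.830822 | 0.983163 |

Because `trainer.train()` was interrupted, `load_best_model_at_end` never took effect: the model evaluated and pushed below has the weights from the point of interruption, not those of the best checkpoint.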


### Saving and evaluating the model

In [26]:
trainer.evaluate(tokenized_wikitext["test"])

{'eval_loss': 0.037630438804626465,
 'eval_precision': 0.8471585502984171,
 'eval_recall': 0.8514300617230288,
 'eval_f1': 0.8492889351370101,
 'eval_accuracy': 0.9848677451373048}

In [27]:
from huggingface_hub import notebook_login

notebook_login()

In [28]:
hub_name = "klasocki/roberta-large-lora-ner-comma-fixer"

In [29]:
model.push_to_hub(hub_name)

CommitInfo(commit_url='https://huggingface.co/klasocki/roberta-large-lora-ner-comma-fixer/commit/b6e99b176b6814a75e841edcfaa8fef649feaf31', commit_message='Upload model', commit_description='', oid='b6e99b176b6814a75e841edcfaa8fef649feaf31', pr_url=None, pr_revision=None, pr_num=None)

### Inference

In [30]:
peft_model_id = hub_name
config = PeftConfig.from_pretrained(peft_model_id)
inference_model = AutoModelForTokenClassification.from_pretrained(
    config.base_model_name_or_path, num_labels=len(label_list), id2label=id2label, label2id=label2id
)
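# Note: training used add_prefix_space=True, so without it the first word is tokenized
# slightly differently than during training (e.g. 'This' instead of 'ĠThis')
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)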
model = PeftModel.from_pretrained(inference_model, peft_model_id)

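Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

This warning is expected here: the classifier head is newly created on the base model, and `PeftModel.from_pretrained` then loads the trained classifier, which PEFT saves together with the adapter for token classification tasks, on top of it.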


In [34]:
text = "This text should have commas here here and there however it does not."
inputs = tokenizer(text, return_tensors="pt")

In [35]:
with torch.no_grad():
    logits = model(**inputs).logits

tokens = inputs.tokens()
predictions = torch.argmax(logits, dim=2)

for token, prediction in zip(tokens, predictions[0].numpy()):
    print((token, model.config.id2label[prediction]))

('<s>', 'O')
('This', 'O')
('Ġtext', 'O')
('Ġshould', 'O')
('Ġhave', 'O')
('Ġcomm', 'O')
('as', 'O')
('Ġhere', 'B-COMMA')
('Ġhere', 'O')
('Ġand', 'O')
('Ġthere', 'B-COMMA')
('Ġhowever', 'O')
('Ġit', 'O')
('Ġdoes', 'O')
('Ġnot', 'O')
('.', 'O')
('</s>', 'O')
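The printed tags still have to be turned back into text. A minimal sketch, assuming ASCII input: the hypothetical helper `predictions_to_text` groups byte-level BPE pieces into words using the `Ġ` space marker, takes each word's tag from its first piece (matching how labels were aligned during training), and appends a comma after words tagged `B-COMMA`. Pieces of non-ASCII characters (like `Åį` earlier) would need real byte-level decoding, so a sturdier version would rely on the fast tokenizer's `inputs.word_ids()` and `tokenizer.convert_tokens_to_string` instead.

```python
from __future__ import annotations

SPECIAL_TOKENS = {"<s>", "</s>", "<pad>"}


def predictions_to_text(token_label_pairs: list[tuple[str, str]]) -> str:
    """Rebuild comma-punctuated text from (BPE token, predicted label) pairs."""
    words: list[str] = []
    comma_after: list[bool] = []
    for token, label in token_label_pairs:
        if token in SPECIAL_TOKENS:
            continue
        if token.startswith("Ġ") or not words:
            # A new word starts; its first piece decides the comma
            words.append(token[1:] if token.startswith("Ġ") else token)
            comma_after.append(label == "B-COMMA")
        else:
            words[-1] += token  # continuation piece, e.g. "comm" + "as"
    return " ".join(w + "," if c else w for w, c in zip(words, comma_after))


pairs = [
    ("<s>", "O"), ("This", "O"), ("Ġtext", "O"), ("Ġshould", "O"), ("Ġhave", "O"),
    ("Ġcomm", "O"), ("as", "O"), ("Ġhere", "B-COMMA"), ("Ġhere", "O"), ("Ġand", "O"),
    ("Ġthere", "B-COMMA"), ("Ġhowever", "O"), ("Ġit", "O"), ("Ġdoes", "O"),
    ("Ġnot", "O"), (".", "O"), ("</s>", "O"),
]
print(predictions_to_text(pairs))
# This text should have commas here, here and there, however it does not.
```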
