File size: 1,418 Bytes
e6b57de 5159d00 37293dc e6b57de 5159d00 553a86b 2bc1a5b 48434be 5159d00 48434be 5159d00 48434be 5159d00 31b9e0c 5159d00 2bc1a5b 48434be 31b9e0c 5159d00 553a86b e7d3e2d 3a38271 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
"""Module for tokenization utilities"""
import logging
from termcolor import colored
# Shared "axolotl" logger; the label-inspection helpers below emit their output through it.
LOG = logging.getLogger(name="axolotl")
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
    """Render the leading rows of ``dataset`` with per-token label coloring.

    Each inspected row is passed to ``check_example_labels``. The dataset is
    expected to be shuffled already, so the first rows act as a sample.

    Args:
        dataset: Indexable, sized collection of tokenized rows; each row must
            provide ``"input_ids"`` and ``"labels"`` (read by
            ``check_example_labels``).
        tokenizer: Forwarded to ``check_example_labels`` for decoding.
        num_examples: Maximum number of rows to inspect. It is clamped to
            ``len(dataset)`` so a short or empty dataset does not raise
            ``IndexError``.
        text_only: Forwarded to ``check_example_labels``.
    """
    # The dataset is already shuffled, so inspecting the leading rows is enough.
    for idx in range(min(num_examples, len(dataset))):
        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer, text_only=False):
    """Render one tokenized example with each token colored by its label.

    Positions of ``example["input_ids"]`` and ``example["labels"]`` are walked
    in lockstep, and each input id is decoded on its own. The decoded token is
    colored by its label:

    - red when the label is -100 (the ignore value),
    - yellow when the label is 0,
    - green otherwise.

    Unless ``text_only`` is set, every token is followed by a white
    ``(label_id, input_id)`` annotation.

    The rendered line is emitted through ``LOG.info`` (followed by a spacer
    record) and also returned.

    Args:
        example: Mapping providing ``"input_ids"`` and ``"labels"`` sequences.
        tokenizer: Object whose ``decode`` is called with a single token id.
        text_only: If True, omit the ``(label_id, input_id)`` annotations.

    Returns:
        str: The colored tokens (as produced by termcolor's ``colored``)
        joined with single spaces.
    """
    input_ids = example["input_ids"]
    labels = example["labels"]

    colored_tokens = []
    # zip stops at the shorter sequence: if input_ids and labels differ in
    # length, the trailing positions are silently not shown.
    for input_id, label_id in zip(input_ids, labels):
        decoded_input_token = tokenizer.decode(input_id)
        # -100 is the ignore value -> red. Label 0 is highlighted separately
        # (yellow); presumably this flags id-0/padding labels -- TODO confirm.
        color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
        colored_token = colored(decoded_input_token, color)
        if not text_only:
            colored_token += colored(f"({label_id}, {input_id})", "white")
        colored_tokens.append(colored_token)

    rendered = " ".join(colored_tokens)
    LOG.info(rendered)
    LOG.info("\n\n\n")
    return rendered
|