|
"""Module for tokenization utilities""" |
|
|
|
|
|
import logging |
|
|
|
from termcolor import colored |
|
|
|
# Module logger. It uses the fixed "axolotl" name rather than __name__,
# presumably so it shares handler/level configuration with the rest of the
# package — confirm against the project's logging setup.
LOG = logging.getLogger(name="axolotl")
|
|
|
|
|
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
    """Render the first few examples of ``dataset`` via ``check_example_labels``.

    Args:
        dataset: indexable collection of examples, accessed as ``dataset[idx]``
            for ``idx`` in ``0 .. num_examples - 1``. Each example is passed
            unchanged to ``check_example_labels``.
        tokenizer: forwarded to ``check_example_labels`` to decode token ids.
        num_examples: maximum number of leading examples to inspect. If the
            dataset raises IndexError earlier (i.e. it is shorter), inspection
            stops there instead of failing.
        text_only: forwarded to ``check_example_labels``. When True, the
            per-token ``(label_id, input_id)`` annotations are omitted.
    """
    for idx in range(num_examples):
        # A dataset shorter than ``num_examples`` previously crashed here;
        # stop at the end instead. Only the indexing is guarded, so an
        # IndexError raised while checking an example still propagates.
        try:
            example = dataset[idx]
        except IndexError:
            break
        check_example_labels(example, tokenizer, text_only=text_only)
|
|
|
|
|
def check_example_labels(example, tokenizer, text_only=False):
    """Render one tokenized example with every token colored by its label.

    Each token id is decoded individually with ``tokenizer.decode`` and colored:
    red when its label is -100 (presumably the loss ignore index — confirm
    against the collator), yellow when its label is 0, green otherwise. Unless
    ``text_only`` is set, each token is followed by a white
    ``(label_id, input_id)`` annotation.

    The rendered string is logged with ``LOG.info`` (followed by a spacer
    entry), printed to stdout, and returned.

    Args:
        example: mapping providing ``"input_ids"`` and ``"labels"`` sequences.
        tokenizer: object whose ``decode`` method accepts a single token id.
        text_only: when True, omit the ``(label_id, input_id)`` annotations.

    Returns:
        The colored tokens joined with single spaces.
    """
    input_ids = example["input_ids"]
    labels = example["labels"]

    # zip() below silently stops at the shorter sequence. This function exists
    # to inspect label alignment, so report a length mismatch instead of
    # hiding it.
    if len(input_ids) != len(labels):
        LOG.warning(
            "input_ids and labels have different lengths (%d vs %d); "
            "only the first %d tokens are shown",
            len(input_ids),
            len(labels),
            min(len(input_ids), len(labels)),
        )

    colored_tokens = []
    for input_id, label_id in zip(input_ids, labels):
        if label_id == -100:
            color = "red"
        elif label_id == 0:
            color = "yellow"
        else:
            color = "green"

        colored_token = colored(tokenizer.decode(input_id), color)
        if not text_only:
            colored_token += colored(f"({label_id}, {input_id})", "white")
        colored_tokens.append(colored_token)

    rendered = " ".join(colored_tokens)

    LOG.info(rendered)
    LOG.info("\n\n\n")
    # Also written to stdout. NOTE(review): this duplicates the LOG.info
    # output when the logger emits to the console; consider keeping only one.
    print(rendered)

    return rendered
|
|