"""Gradio app reporting token, word and character counts for a dataset column.

Choose a tokenizer and a Hugging Face dataset (config, split, text column). The
app tokenizes every item of that split and shows a histogram of tokens per item
plus summary statistics for tokens, words and characters per item.
"""

import io
import json
import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoTokenizer

# Short name -> Hugging Face Hub model id. Only the ids are shown in the UI.
TOKENIZER_IDS = {
    "bert": "google-bert/bert-base-uncased",
    "blenderbot": "facebook/blenderbot-3B",
    "bloom": "bigscience/bloom-560m",
    "bloomz": "bigscience/bloomz-7b1",
    "chatglm3": "THUDM/chatglm3-6b",
    "falcon": "tiiuae/falcon-7b",
    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
    "gpt-neox": "EleutherAI/gpt-neox-20b",
    "llama": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
    "magicoder": "ise-uiuc/Magicoder-S-DS-6.7B",
    "mistral": "mistralai/Mistral-7B-v0.1",
    "mpt": "mosaicml/mpt-7b",
    "opt": "facebook/opt-2.7b",
    "phi-2": "microsoft/phi-2",
    "pythia": "EleutherAI/pythia-1.4b-deduped",
    "qwen": "Qwen/Qwen1.5-7B-Chat",
    "redpajama": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
    "roberta": "FacebookAI/roberta-base",
    "starcoder": "bigcode/starcoder2-7b",
    "t5": "google-t5/t5-base",
    "vicuna": "lmsys/vicuna-7b-v1.5",
    "zephyr": "HuggingFaceH4/zephyr-7b-beta",
}
# Model ids offered in the tokenizer dropdown (custom values are also allowed).
tokenizers = list(TOKENIZER_IDS.values())

# Rough word definition: runs of ASCII letters. Digits, punctuation and
# non-Latin scripts are not counted, so this is "not 100% accurate but good
# enough" for English text.
WORD_PATTERN = re.compile(r"[a-zA-Z]+")


def plot_histogram(data):
    """Render a histogram of *data* (tokens per item) and return it as a PIL image.

    An explicit ``Figure`` is used instead of pyplot's global "current figure",
    so concurrent Gradio requests cannot draw onto each other's plots and no
    figure is left registered with pyplot.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.hist(data)
    ax.set_title("Histogram of number of tokens per dataset item")
    ax.set_xlabel("Tokens per item")
    ax.set_ylabel("Number of items")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


def _describe_counts(tokencounter, wordcounter, charcounter):
    """Summarize per-item counts as one row each for tokens, words and chars.

    Columns are ``type`` followed by the ``Series.describe()`` statistics
    (without ``count``), rounded to one decimal. The rows are built as plain
    floats so that ``round`` actually applies. Appending mixed-type rows with
    ``df.loc`` upcasts the columns to ``object``, and ``round`` skips those.
    """
    rows = []
    for label, values in (
        ("tokens", tokencounter),
        ("words", wordcounter),
        ("chars", charcounter),
    ):
        stats = pd.Series(values, dtype="float64").describe().drop("count")
        rows.append({"type": label, **stats.to_dict()})
    return pd.DataFrame(rows).round(1)


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Tokenize every item of a dataset column and summarize the lengths.

    Args:
        model_id: Hub id of the tokenizer to load.
        dataset_id: Hub id of the dataset.
        config: Dataset config name; an empty string selects the default config.
        split: Dataset split to load (e.g. "test").
        column: Name of the text column to measure.
        add_special_tokens: Passed through to the tokenizer call.

    Returns:
        A ``(histogram_image, summary_dataframe)`` tuple for the Gradio outputs.

    Raises:
        gr.Error: If the column does not exist or the split has no items.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # A blank Gradio textbox yields "". load_dataset expects None for the
    # default config. (The original `config is None` was a no-op comparison.)
    if config == "":
        config = None
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    if column not in dataset.column_names:
        raise gr.Error(
            f"Column {column!r} not found; available columns: {dataset.column_names}"
        )

    tokencounter = []
    wordcounter = []
    charcounter = []
    for item in dataset:
        text = item[column]
        tokens = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(text))
        wordcounter.append(len(WORD_PATTERN.findall(text)))

    if not tokencounter:
        raise gr.Error("The selected dataset split contains no items.")

    df = _describe_counts(tokencounter, wordcounter, charcounter)
    return plot_histogram(tokencounter), df


demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    examples=[
        ["mistralai/Mistral-7B-v0.1", "gsarti/flores_101", "eng", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "Muennighoff/flores200", "eng_Latn", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "hails/mmlu_no_train", "elementary_mathematics", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "gsm8k", "main", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "locuslab/TOFU", "world_facts", "train", "question"],
        ["mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text"],
        ["mistralai/Mistral-7B-v0.1", "wikitext", "wikitext-2-v1", "validation", "text"],
        ["mistralai/Mistral-7B-v0.1", "zeroshot/twitter-financial-news-sentiment", "", "validation", "text"],
    ],
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch()