|
import io |
|
import json |
|
import re |
|
|
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
from datasets import load_dataset |
|
from PIL import Image |
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Model family name paired with the Hub repository id whose tokenizer
# represents that family.
tokenizers = dict(
    [
        ("bert", "google-bert/bert-base-uncased"),
        ("blenderbot", "facebook/blenderbot-3B"),
        ("bloom", "bigscience/bloom-560m"),
        ("bloomz", "bigscience/bloomz-7b1"),
        ("chatglm3", "THUDM/chatglm3-6b"),
        ("falcon", "tiiuae/falcon-7b"),
        ("gemma", "fxmarty/tiny-random-GemmaForCausalLM"),
        ("gpt-neox", "EleutherAI/gpt-neox-20b"),
        ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v0.6"),
        ("magicoder", "ise-uiuc/Magicoder-S-DS-6.7B"),
        ("mistral", "mistralai/Mistral-7B-v0.1"),
        ("mpt", "mosaicml/mpt-7b"),
        ("opt", "facebook/opt-2.7b"),
        ("phi-2", "microsoft/phi-2"),
        ("pythia", "EleutherAI/pythia-1.4b-deduped"),
        ("qwen", "Qwen/Qwen1.5-7B-Chat"),
        ("redpajama", "togethercomputer/RedPajama-INCITE-Chat-3B-v1"),
        ("roberta", "FacebookAI/roberta-base"),
        ("starcoder", "bigcode/starcoder2-7b"),
        ("t5", "google-t5/t5-base"),
        ("vicuna", "lmsys/vicuna-7b-v1.5"),
        ("zephyr", "HuggingFaceH4/zephyr-7b-beta"),
    ]
)

# Only the repository ids are kept; the family names are discarded here.
tokenizers = [*tokenizers.values()]
|
|
|
|
|
def plot_histogram(data):
    """Render a histogram of ``data`` and return it as a PIL image.

    Args:
        data: Values handed directly to Matplotlib's ``hist``.

    Returns:
        The image obtained by opening the PNG rendering of the figure with
        ``PIL.Image.open``.
    """
    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure, so a leftover figure from an earlier failed call can never be
    # drawn on again.
    fig, ax = plt.subplots()
    try:
        ax.hist(data)
        ax.set_title("Histogram of number of tokens per dataset item")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, including when plotting or saving raises.
        plt.close(fig)
    buf.seek(0)
    # The buffer is deliberately left open: the returned image reads from it.
    return Image.open(buf)
|
|
|
|
|
def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Tokenize one column of a dataset split and summarise per-item lengths.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_id: Identifier passed to ``load_dataset``.
        config: Dataset configuration name. An empty value (``""`` or
            ``None``) is forwarded to ``load_dataset`` as ``None``.
        split: Split name passed to ``load_dataset``.
        column: Key read from every dataset item. The value is tokenized,
            measured with ``len`` and searched with a regex, so it is
            presumably a string.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A tuple ``(image, table)``. ``image`` is the result of
        ``plot_histogram`` for the per-item token counts. ``table`` is a
        DataFrame with one row each for tokens, words and chars (in that
        order); its columns are ``type`` followed by the ``describe()``
        statistics except ``count``, rounded to one decimal.

    Raises:
        ValueError: If the selected split contains no items.
    """
    # NOTE(review): trust_remote_code=True allows the referenced Hub
    # repositories to execute their own Python code in this process, and both
    # ids are supplied by the caller. Confirm this is acceptable where the app
    # is deployed.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # An empty config must be forwarded as None (the original statement
    # `config is None` was a comparison and changed nothing).
    config = config or None
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    # A "word" is a maximal run of ASCII letters.
    word_pattern = re.compile(r"[a-zA-Z]+")

    token_counts = []
    word_counts = []
    char_counts = []
    for item in dataset:
        text = item[column]
        input_ids = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        token_counts.append(len(input_ids))
        char_counts.append(len(text))
        word_counts.append(len(word_pattern.findall(text)))

    if not token_counts:
        raise ValueError(
            f"No items found in split {split!r} of dataset {dataset_id!r}; nothing to count."
        )

    # describe() yields one column of statistics per measure; transposing gives
    # one row per measure, and the measure name becomes the leading "type"
    # column.
    summary = (
        pd.DataFrame(
            {"tokens": token_counts, "words": word_counts, "chars": char_counts}
        )
        .describe()
        .T.drop(columns="count")
        .round(1)
        .rename_axis("type")
        .reset_index()
    )

    return plot_histogram(token_counts), summary
|
|
|
|
|
# Gradio front end for `count`: the six input components are passed to it
# positionally, and its two return values fill the two output components.
demo = gr.Interface(
    title="Dataset token counts and distribution",
    fn=count,
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    # Every example uses the same tokenizer and lists only the first five
    # inputs (tokenizer, dataset, config, split, column); no checkbox value is
    # given.
    examples=[
        ["mistralai/Mistral-7B-v0.1", dataset_id, config_name, split_name, column_name]
        for dataset_id, config_name, split_name, column_name in (
            ("gsarti/flores_101", "eng", "dev", "sentence"),
            ("Muennighoff/flores200", "eng_Latn", "dev", "sentence"),
            ("hails/mmlu_no_train", "elementary_mathematics", "test", "question"),
            ("gsm8k", "main", "test", "question"),
            ("locuslab/TOFU", "world_facts", "train", "question"),
            ("imdb", "", "test", "text"),
            ("wikitext", "wikitext-2-v1", "validation", "text"),
            ("zeroshot/twitter-financial-news-sentiment", "", "validation", "text"),
        )
    ],
    cache_examples=True,
)

demo.launch()
|
|