# Page header captured together with this file (not part of the program):
#   helenai's picture
#   Initial commit
#   57b690d
#   raw
#   history blame
#   3.15 kB
import io
import json
import re
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer
# Tokenizer checkpoints offered as preset choices in the "Tokenizer" dropdown.
# The dropdown also accepts custom values, so this list is only a convenience.
tokenizers = [
    "google/gemma-7b",
    "meta-llama/Llama-2-7b",
    "mistralai/Mistral-7B-v0.1",
    "facebook/opt-2.7b",
    "microsoft/phi-2",
    "THUDM/chatglm3-6b",
    "Qwen/Qwen1.5-7B-Chat",
    "bigscience/bloom-560m",
    "ise-uiuc/Magicoder-S-DS-6.7B",
    "google/flan-t5-base",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
def plot_histogram(data):
    """Render a histogram of *data* and return it as a PIL image.

    Args:
        data: Sequence of numbers to bin (the caller passes per-item token
            counts).

    Returns:
        The image obtained by opening an in-memory PNG of the plot with
        ``PIL.Image.open``.
    """
    # Draw on a dedicated figure rather than pyplot's implicit current figure.
    # The implicit figure survives between calls, so every new request would
    # paint its histogram on top of the previous ones.
    fig, ax = plt.subplots()
    try:
        ax.hist(data)
        ax.set_title("Histogram of number of tokens per dataset item")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure so repeated calls do not pile up open
        # figures in pyplot's registry.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _describe_counts(values, label):
    """Return a one-row ``describe()`` summary of *values* tagged with *label*."""
    summary = pd.DataFrame(values).describe().T
    summary.insert(0, "type", label)
    return summary


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    """Tokenize one text column of a dataset split and summarize item lengths.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_id: Identifier passed to ``load_dataset``.
        config: Dataset configuration name; an empty value means "no explicit
            configuration" and is forwarded as ``None``.
        split: Split name passed to ``load_dataset``.
        column: Name of the dataset column holding the text to measure.
        add_special_tokens: Forwarded to the tokenizer call.

    Returns:
        A tuple ``(histogram, stats)``. ``histogram`` is whatever
        ``plot_histogram`` returns for the per-item token counts. ``stats`` is
        a DataFrame with one row each for tokens, words and chars (in that
        order), holding the ``describe()`` statistics without ``count``,
        rounded to one decimal.
    """
    # NOTE(review): trust_remote_code=True allows the repositories named by
    # model_id / dataset_id to run their own code in this process, and both ids
    # come from user input. Confirm this is acceptable for the deployment.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # A blank "Config" textbox arrives as ""; forward None instead so that
    # load_dataset does not look for a configuration literally named "".
    if not config:
        config = None
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)

    # Words are approximated as runs of ASCII letters.
    word_pattern = re.compile(r"[a-zA-Z]+")
    tokencounter = []
    wordcounter = []
    charcounter = []
    for item in dataset:
        text = item[column]
        tokens = tokenizer(text, add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(text))
        # not 100% accurate but good enough
        wordcounter.append(len(word_pattern.findall(text)))

    # Stack the three one-row summaries: tokens, then words, then chars.
    df = pd.concat(
        [
            _describe_counts(tokencounter, "tokens"),
            _describe_counts(wordcounter, "words"),
            _describe_counts(charcounter, "chars"),
        ],
        ignore_index=True,
    )
    # Every row has its own item count; the column is dropped from the report.
    df = df.round(1).drop(columns="count")
    return plot_histogram(tokencounter), df
# Input widgets, in the positional order expected by count().
input_components = [
    gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
    gr.Textbox(label="Dataset"),
    gr.Textbox(label="Config"),
    gr.Textbox(label="Split"),
    gr.Textbox(label="Column"),
    gr.Checkbox(label="Add special tokens", value=True),
]

# Output widgets: the histogram image and the statistics table.
output_components = [
    gr.Image(),
    gr.Dataframe(label="Token, word and character counts per dataset item"),
]

# Each example is (dataset, config, split, column); all of them use the same
# tokenizer, which is prepended when the example rows are built below.
example_datasets = [
    ("gsarti/flores_101", "eng", "dev", "sentence"),
    ("Muennighoff/flores200", "eng_Latn", "dev", "sentence"),
    ("wikitext", "wikitext-2-v1", "validation", "text"),
    ("hails/mmlu_no_train", "elementary_mathematics", "test", "question"),
    ("imdb", "", "test", "text"),
    ("gsm8k", "main", "test", "question"),
    ("locuslab/TOFU", "world_facts", "train", "question"),
]

demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=input_components,
    outputs=output_components,
    examples=[["mistralai/Mistral-7B-v0.1", *row] for row in example_datasets],
    cache_examples=False,
)

# Start the Gradio app.
demo.launch()