# TokenizerViz / app.py

import gradio as gr
from transformers import AutoTokenizer
import ast
from collections import Counter
import re
import plotly.graph_objs as go
import html
import random
import tiktoken
import anthropic
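
# Directory holding the locally downloaded Hugging Face model/tokenizer files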
model_path = "models/"

# Models exposed in the UI dropdown. "Claude-3-Sonnet" is also handled in the
# processing code below, but is not listed here since the API-key field is commented out.
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b", "gpt-3.5-turbo", "gpt-4", "gpt-4o"]
openai_models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]

# Color palette visible on both light and dark themes
COLOR_PALETTE = [
    "#e6194B", "#3cb44b", "#ffe119", "#4363d8",
    "#f58231", "#911eb4", "#42d4f4", "#f032e6",
    "#bfef45", "#fabed4", "#469990", "#dcbeff",
    "#9A6324", "#fffac8", "#800000", "#aaffc3",
    "#808000", "#ffd8b1", "#000075", "#a9a9a9"
]
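
# Build a vertical Plotly bar chart from (label, count) pairs,
# e.g. the output of Counter.most_common().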
def create_vertical_histogram(data, title):
    labels, values = zip(*data) if data else ([], [])
    fig = go.Figure(go.Bar(
        x=labels,
        y=values
    ))
    fig.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        xaxis=dict(tickangle=-45)
    )
    return fig
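
# Check that the raw input matches the selected input type: free-form text,
# or a Python-literal list of integer token IDs such as "[1, 2, 3]".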
def validate_input(input_type, input_value):
    if input_type == "Text":
        if not isinstance(input_value, str):
            return False, "Input must be a string for Text input type."
    elif input_type == "Token IDs":
        try:
            token_ids = ast.literal_eval(input_value)
            if not isinstance(token_ids, list) or not all(isinstance(token_id, int) for token_id in token_ids):
                return False, "Token IDs must be a list of integers."
        except (ValueError, SyntaxError):
            return False, "Invalid Token IDs format. Please provide a valid list of integers."
    return True, ""
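
# Tokenize text with the backend matching the selected model: local Hugging Face
# tokenizers for the Llama/Gemma checkpoints, tiktoken for the OpenAI models,
# and the Anthropic SDK tokenizer for Claude. Returns (text, tokens, token_ids).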
def process_text(text: str, model_name: str, api_key: str = None):
    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
        tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
    elif model_name in openai_models:
        encoding = tiktoken.encoding_for_model(model_name=model_name)
        token_ids = encoding.encode(text)
        tokens = [encoding.decode([token_id]) for token_id in token_ids]
    elif model_name == "Claude-3-Sonnet":
        if not api_key:
            raise ValueError("API key is required for Claude models")
        client = anthropic.Anthropic(api_key=api_key)
        # Note: get_tokenizer() is only available in older anthropic SDK releases.
        tokenizer = client.get_tokenizer()
        token_ids = tokenizer.encode(text).ids
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids]
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return text, tokens, token_ids
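
# Inverse of process_text: decode a string of token IDs back into text and
# per-token strings, using the same backend selection.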
def process_ids(ids: str, model_name: str, api_key: str = None):
    token_ids = ast.literal_eval(ids)
    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
        tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
        text = tokenizer.decode(token_ids)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
    elif model_name in openai_models:
        encoding = tiktoken.encoding_for_model(model_name=model_name)
        text = encoding.decode(token_ids)
        tokens = [encoding.decode([token_id]) for token_id in token_ids]
    elif model_name == "Claude-3-Sonnet":
        client = anthropic.Anthropic(api_key=api_key)
        tokenizer = client.get_tokenizer()
        text = tokenizer.decode(token_ids)
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids]
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return text, tokens, token_ids
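
# Pick a background color for a token: fixed colors for special tokens, spaces,
# and punctuation; otherwise a palette color cached in token_colors so that
# repeated tokens keep the same color.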
def get_token_color(token, token_colors):
    if token.startswith('<') and token.endswith('>'):
        return "#42d4f4"  # Cyan for special tokens
    elif token in ('▁', ' '):
        return "#3cb44b"  # Green for space tokens
    elif not token.isalnum():
        return "#f032e6"  # Magenta for special characters
    else:
        if token not in token_colors:
            token_colors[token] = random.choice(COLOR_PALETTE)
        return token_colors[token]
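
# Render the token list as inline HTML, one colored <span> chip per token.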
def create_html_tokens(tokens):
    html_output = '<div style="font-family: monospace; border: 1px solid #ccc; padding: 10px; border-radius: 5px; background-color: #f9f9f9; white-space: pre-wrap; word-break: break-all;">'
    token_colors = {}
    for token in tokens:
        color = get_token_color(token, token_colors)
        escaped_token = html.escape(token)
        html_output += f'<span style="background-color: {color}; color: black; padding: 2px 4px; margin: 1px; border-radius: 3px; display: inline-block;">{escaped_token}</span>'
    html_output += '</div>'
    return html_output
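
# Main Gradio callback: validate the input, tokenize or decode it, then build
# summary statistics and histograms of the most common words, special
# characters, and numbers in the decoded text.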
def process_input(input_type, input_value, model_name, api_key=None):
    is_valid, error_message = validate_input(input_type, input_value)
    if not is_valid:
        raise gr.Error(error_message)

    if input_type == "Text":
        text, tokens, token_ids = process_text(text=input_value, model_name=model_name, api_key=api_key)
    elif input_type == "Token IDs":
        text, tokens, token_ids = process_ids(ids=input_value, model_name=model_name, api_key=api_key)

    character_count = len(text)
    word_count = len(text.split())
    space_count = sum(1 for token in tokens if token in ['▁', ' '])
    special_char_count = sum(1 for token in tokens if not token.isalnum() and token not in ['▁', ' '])

    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    most_common_words = Counter(words).most_common(10)
    most_common_special_chars = Counter(special_chars).most_common(10)
    most_common_numbers = Counter(numbers).most_common(10)

    words_hist = create_vertical_histogram(most_common_words, "Most Common Words")
    special_chars_hist = create_vertical_histogram(most_common_special_chars, "Most Common Special Characters")
    numbers_hist = create_vertical_histogram(most_common_numbers, "Most Common Numbers")

    analysis = f"Token count: {len(tokens)}\n"
    analysis += f"Character count: {character_count}\n"
    analysis += f"Word count: {word_count}\n"
    analysis += f"Space tokens: {space_count}\n"
    analysis += f"Special character tokens: {special_char_count}\n"
    analysis += f"Other tokens: {len(tokens) - space_count - special_char_count}"

    html_tokens = create_html_tokens(tokens)
    return analysis, text, html_tokens, str(token_ids), words_hist, special_chars_hist, numbers_hist
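
# Sample inputs for the "Load ... Example" buttons below.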
def text_example():
    return "Hello, world! This is an example text input for tokenization."

def token_ids_example():
    return "[128000, 9906, 11, 1917, 0, 1115, 374, 459, 3187, 1495, 1988, 369, 4037, 2065, 13]"
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of the most common elements.")
    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
        # api_key = gr.Textbox(label="API Key (Claude models)", type="password")
    input_text = gr.Textbox(lines=5, label="Input")
    with gr.Row():
        text_example_button = gr.Button("Load Text Example")
        token_ids_example_button = gr.Button("Load Token IDs Example")
    submit_button = gr.Button("Process")
    analysis_output = gr.Textbox(label="Analysis", lines=6)
    text_output = gr.Textbox(label="Text", lines=6)
    tokens_output = gr.HTML(label="Tokens")
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)
    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # The example buttons fill the input box and switch the input type to match.
    text_example_button.click(
        lambda: (text_example(), "Text"),
        outputs=[input_text, input_type]
    )
    token_ids_example_button.click(
        lambda: (token_ids_example(), "Token IDs"),
        outputs=[input_text, input_type]
    )
    submit_button.click(
        process_input,
        inputs=[input_type, input_text, model_name],
        outputs=[analysis_output, text_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
    )

if __name__ == "__main__":
    iface.launch()