Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer | |
import ast | |
from collections import Counter | |
import re | |
import plotly.graph_objs as go | |
import html | |
import random | |
import tiktoken | |
import anthropic | |
# Directory holding the local Hugging Face tokenizer checkpoints. Model names are
# appended to it by plain string concatenation, so the trailing slash matters.
model_path = "models/"

# Models tokenized with tiktoken.
openai_models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]

# Available models: the two local checkpoints first, then the OpenAI models.
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"] + openai_models

# Color palette visible on both light and dark themes
COLOR_PALETTE = [
    "#e6194B",
    "#3cb44b",
    "#ffe119",
    "#4363d8",
    "#f58231",
    "#911eb4",
    "#42d4f4",
    "#f032e6",
    "#bfef45",
    "#fabed4",
    "#469990",
    "#dcbeff",
    "#9A6324",
    "#fffac8",
    "#800000",
    "#aaffc3",
    "#808000",
    "#ffd8b1",
    "#000075",
    "#a9a9a9",
]
def create_vertical_histogram(data, title):
    """Build a Plotly bar chart from ``(label, count)`` pairs.

    Args:
        data: A sequence of ``(label, count)`` pairs, such as the output of
            ``Counter.most_common``. May be empty, which yields an empty chart.
        title: The chart title.

    Returns:
        A ``go.Figure`` with one bar per pair, 400 px tall, with x tick labels
        rotated by -45 degrees.
    """
    if data:
        labels, values = zip(*data)
    else:
        labels, values = [], []

    bar = go.Bar(x=labels, y=values)
    figure = go.Figure(bar)
    figure.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        xaxis={"tickangle": -45},
    )
    return figure
def validate_input(input_type, input_value):
    """Check that *input_value* is acceptable for the selected *input_type*.

    Args:
        input_type: ``"Text"`` or ``"Token IDs"``. Any other value is rejected.
        input_value: The raw value from the input textbox.

    Returns:
        An ``(is_valid, error_message)`` tuple. ``error_message`` is ``""`` when
        the input is valid.
    """
    if input_type == "Text":
        if not isinstance(input_value, str):
            return False, "Input must be a string for Text input type."
        return True, ""

    if input_type == "Token IDs":
        try:
            # literal_eval only evaluates Python literals, so it is safe on user
            # input. Besides ValueError/SyntaxError it can raise TypeError (e.g.
            # "{[1]: 2}"), MemoryError or RecursionError on malformed input.
            token_ids = ast.literal_eval(input_value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return False, "Invalid Token IDs format. Please provide a valid list of integers."
        # bool is a subclass of int, so "[True]" must be excluded explicitly.
        if not isinstance(token_ids, list) or not all(
            isinstance(token_id, int) and not isinstance(token_id, bool)
            for token_id in token_ids
        ):
            return False, "Token IDs must be a list of integers."
        if any(token_id < 0 for token_id in token_ids):
            return False, "Token IDs must be non-negative integers."
        return True, ""

    # Reject unknown types here so callers never proceed without a processing branch.
    return False, f"Unsupported input type: {input_type}"
def process_text(text: str, model_name: str, api_key: str = None):
    """Tokenize *text* with the tokenizer belonging to *model_name*.

    Args:
        text: The text to tokenize.
        model_name: ``"Meta-Llama-3.1-8B"`` or ``"gemma-2b"`` (loaded from
            ``model_path``), a name in ``openai_models`` (tiktoken), or
            ``"Claude-3-Sonnet"``.
        api_key: Anthropic API key. Only used, and then required, for
            ``"Claude-3-Sonnet"``.

    Returns:
        ``(text, tokens, token_ids)``: the unchanged input text, the per-token
        strings, and the token IDs.

    Raises:
        ValueError: for an unsupported model, or for the Claude model when no
            *api_key* is given.
    """
    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
        hf_tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
        ids_out = hf_tokenizer.encode(text, add_special_tokens=True)
        pieces = hf_tokenizer.convert_ids_to_tokens(ids_out)
        return text, pieces, ids_out

    if model_name in openai_models:
        enc = tiktoken.encoding_for_model(model_name=model_name)
        ids_out = enc.encode(text)
        # Decode each ID on its own to get the text fragment each token covers.
        pieces = [enc.decode([single]) for single in ids_out]
        return text, pieces, ids_out

    if model_name != "Claude-3-Sonnet":
        raise ValueError(f"Unsupported model: {model_name}")

    if not api_key:
        raise ValueError("API key is required for Claude models")
    claude_tokenizer = anthropic.Anthropic(api_key=api_key).get_tokenizer()
    ids_out = claude_tokenizer.encode(text).ids
    pieces = [claude_tokenizer.decode([single]) for single in ids_out]
    return text, pieces, ids_out
def process_ids(ids: str, model_name: str, api_key: str = None):
    """Decode a list of token IDs with the tokenizer belonging to *model_name*.

    Args:
        ids: A Python-literal list of integers such as ``"[1, 2, 3]"``, parsed
            with ``ast.literal_eval``. ``process_input`` validates it with
            ``validate_input`` before calling this function.
        model_name: ``"Meta-Llama-3.1-8B"`` or ``"gemma-2b"`` (loaded from
            ``model_path``), a name in ``openai_models`` (tiktoken), or
            ``"Claude-3-Sonnet"``.
        api_key: Anthropic API key. Only used, and then required, for
            ``"Claude-3-Sonnet"``.

    Returns:
        ``(text, tokens, token_ids)``: the decoded text, the per-token strings,
        and the parsed token IDs.

    Raises:
        ValueError: for an unsupported model, or for the Claude model when no
            *api_key* is given.
    """
    token_ids = ast.literal_eval(ids)
    if model_name in ["Meta-Llama-3.1-8B", "gemma-2b"]:
        tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
        text = tokenizer.decode(token_ids)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
    elif model_name in openai_models:  # membership test: openai_models is a list of names
        encoding = tiktoken.encoding_for_model(model_name=model_name)
        text = encoding.decode(token_ids)
        # Decode each ID on its own to get the text fragment each token covers.
        tokens = [encoding.decode([token_id]) for token_id in token_ids]
    elif model_name == "Claude-3-Sonnet":
        # Fail fast with the same message process_text uses, rather than letting
        # the client fail later with a less specific error.
        if not api_key:
            raise ValueError("API key is required for Claude models")
        client = anthropic.Anthropic(api_key=api_key)
        tokenizer = client.get_tokenizer()
        text = tokenizer.decode(token_ids)
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids]
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return text, tokens, token_ids
def get_token_color(token, token_colors):
    """Pick a background colour for *token* in the HTML token view.

    Colour rules:
        - Special tokens like ``<s>`` are cyan.
        - Bare space tokens are green.
        - Any other token that is not purely alphanumeric is magenta.
        - Every remaining token gets a random ``COLOR_PALETTE`` colour. That
          colour is remembered in *token_colors*, so repeated tokens share it.

    Args:
        token: The token string to colour.
        token_colors: Dict mapping token -> colour. It is updated in place when
            a new alphanumeric token is seen.

    Returns:
        A ``#rrggbb`` colour string.
    """
    # Require two or more characters so a lone "<" is treated as punctuation
    # rather than as a special token such as "<s>" or "<|begin_of_text|>".
    if len(token) >= 2 and token.startswith("<") and token.endswith(">"):
        return "#42d4f4"  # Cyan for special tokens
    # "▁" (U+2581) is SentencePiece's whitespace marker and "Ġ" (U+0120) the
    # byte-level BPE one. A token that is just one of these, or a literal
    # space, is a space token.
    if token in ("▁", "Ġ", " "):
        return "#3cb44b"  # Green for space tokens
    if not token.isalnum():
        return "#f032e6"  # Magenta for special characters
    if token not in token_colors:
        token_colors[token] = random.choice(COLOR_PALETTE)
    return token_colors[token]
def _readable_text_color(background):
    """Return "black" or "white", whichever contrasts more with a ``#rrggbb`` *background*.

    Uses the WCAG 2 relative-luminance formula. Returns "black" for anything
    that is not a six-digit hex colour.
    """
    hex_digits = background.lstrip("#")
    if len(hex_digits) != 6:
        return "black"
    try:
        channels = [int(hex_digits[i:i + 2], 16) / 255 for i in (0, 2, 4)]
    except ValueError:
        return "black"
    linear = [c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 for c in channels]
    luminance = 0.2126 * linear[0] + 0.7152 * linear[1] + 0.0722 * linear[2]
    # Contrast vs black is (L + 0.05) / 0.05 and vs white is 1.05 / (L + 0.05).
    # The two are equal at L = sqrt(1.05 * 0.05) - 0.05, which is about 0.179.
    return "black" if luminance > 0.179 else "white"


def create_html_tokens(tokens):
    """Render *tokens* as coloured inline ``<span>`` chips inside a bordered ``<div>``.

    Each token's background comes from ``get_token_color``. Its text colour is
    black or white, whichever is more readable on that background: palette
    entries such as ``#000075`` are too dark for black text.

    Args:
        tokens: Iterable of token strings. May be empty.

    Returns:
        The HTML string for a ``gr.HTML`` component.
    """
    token_colors = {}
    spans = []
    for token in tokens:
        color = get_token_color(token, token_colors)
        text_color = _readable_text_color(color)
        # Escape the token: it is model/user text and may contain "<", "&", quotes.
        spans.append(
            f'<span style="background-color: {color}; color: {text_color}; padding: 2px 4px; '
            f'margin: 1px; border-radius: 3px; display: inline-block;">{html.escape(token)}</span>'
        )
    # join() builds the output in one pass instead of repeated string concatenation.
    return (
        '<div style="font-family: monospace; border: 1px solid #ccc; padding: 10px; '
        'border-radius: 5px; background-color: #f9f9f9; white-space: pre-wrap; '
        'word-break: break-all;">' + "".join(spans) + "</div>"
    )
def process_input(input_type, input_value, model_name, api_key=None):
    """Validate the UI input, tokenize or decode it, and build every output panel.

    Args:
        input_type: ``"Text"`` or ``"Token IDs"``.
        input_value: Raw textbox content.
        model_name: Model whose tokenizer to use; forwarded to ``process_text``
            or ``process_ids``.
        api_key: Optional API key forwarded to those helpers. It defaults to
            ``None`` so callers without an API-key input can omit it.

    Returns:
        A 7-tuple for the UI outputs: the analysis summary, the text, the token
        HTML, the token IDs as a string, and three histogram figures (words,
        special characters, numbers).

    Raises:
        gr.Error: for input that fails ``validate_input`` or an unknown
            *input_type*.
        ValueError: propagated from ``process_text`` / ``process_ids``.
    """
    is_valid, error_message = validate_input(input_type, input_value)
    if not is_valid:
        raise gr.Error(error_message)
    if input_type == "Text":
        text, tokens, token_ids = process_text(text=input_value, model_name=model_name, api_key=api_key)
    elif input_type == "Token IDs":
        text, tokens, token_ids = process_ids(ids=input_value, model_name=model_name, api_key=api_key)
    else:
        # Without this branch an unknown type would surface below as an UnboundLocalError.
        raise gr.Error(f"Unsupported input type: {input_type}")

    # Token-level counts. A token that is exactly a space marker counts as a
    # space token: "▁" (SentencePiece), "Ġ" (byte-level BPE), or a literal space.
    space_markers = ("▁", "Ġ", " ")
    space_count = sum(1 for token in tokens if token in space_markers)
    special_char_count = sum(
        1 for token in tokens if not token.isalnum() and token not in space_markers
    )
    other_count = len(tokens) - space_count - special_char_count

    # Text-level statistics: the input text in Text mode, the decoded text in Token IDs mode.
    character_count = len(text)
    word_count = len(text.split())
    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    words_hist = create_vertical_histogram(Counter(words).most_common(10), "Most Common Words")
    special_chars_hist = create_vertical_histogram(
        Counter(special_chars).most_common(10), "Most Common Special Characters"
    )
    numbers_hist = create_vertical_histogram(Counter(numbers).most_common(10), "Most Common Numbers")

    analysis = "\n".join([
        f"Token count: {len(tokens)}",
        f"Character count: {character_count}",
        f"Word count: {word_count}",
        f"Space tokens: {space_count}",
        f"Special character tokens: {special_char_count}",
        f"Other tokens: {other_count}",
    ])
    html_tokens = create_html_tokens(tokens)
    return analysis, text, html_tokens, str(token_ids), words_hist, special_chars_hist, numbers_hist
def text_example():
    """Return the sample sentence loaded by the "Load Text Example" button."""
    sample = "Hello, world! This is an example text input for tokenization."
    return sample


def token_ids_example():
    """Return a sample token-ID list, as a string, for the "Load Token IDs Example" button.

    NOTE(review): 128000 looks like the Meta-Llama-3.1-8B begin-of-text ID, so
    these IDs presumably only decode meaningfully with that model. TODO: confirm.
    """
    # str() of a list renders as "[a, b, ...]", which ast.literal_eval parses back into a list.
    return str([128000, 9906, 11, 1917, 0, 1115, 374, 459, 3187, 1495, 1988, 369, 4037, 2065, 13])
# Gradio UI: input controls, example loaders, output panels, and event wiring.
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")
    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
    # TODO: add a password Textbox for an API key and wire it into process_input
    # once a model that needs one (e.g. "Claude-3-Sonnet") is added to MODELS.
    input_text = gr.Textbox(lines=5, label="Input")
    with gr.Row():
        text_example_button = gr.Button("Load Text Example")
        token_ids_example_button = gr.Button("Load Token IDs Example")
    submit_button = gr.Button("Process")
    analysis_output = gr.Textbox(label="Analysis", lines=6)
    text_output = gr.Textbox(label="Text", lines=6)
    tokens_output = gr.HTML(label="Tokens")
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)
    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # Example buttons fill the input box and switch the input type to match.
    text_example_button.click(
        lambda: (text_example(), "Text"),
        outputs=[input_text, input_type],
    )
    token_ids_example_button.click(
        lambda: (token_ids_example(), "Token IDs"),
        outputs=[input_text, input_type],
    )
    submit_button.click(
        # There is no API-key component, so only three values arrive. Pass
        # api_key=None explicitly so this call does not rely on process_input
        # giving api_key a default.
        lambda selected_type, raw_input, selected_model: process_input(
            selected_type, raw_input, selected_model, None
        ),
        inputs=[input_type, input_text, model_name],
        outputs=[analysis_output, text_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot],
    )

if __name__ == "__main__":
    iface.launch()