import gradio as gr
from transformers import T5TokenizerFast, CLIPTokenizer

# Tokenizer checkpoints on the Hugging Face Hub. Each tokenizer is created a
# single time at import and then shared by every call to count_tokens.
_T5_CHECKPOINT = "google/t5-v1_1-xxl"
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

t5_tokenizer = T5TokenizerFast.from_pretrained(_T5_CHECKPOINT, legacy=False)
clip_tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)

def _decode_each(tokenizer, token_ids, empty_overrides=None):
    """Decode each token id on its own and return display strings.

    Whitespace-only pieces are rendered as "␣". Pieces that decode to the
    empty string are looked up in ``empty_overrides`` (token id -> display
    string) and otherwise rendered as "∅".
    """
    overrides = empty_overrides or {}
    pieces = []
    for token_id in token_ids:
        # skip_special_tokens=False so special tokens are rendered rather
        # than silently dropped from the visualization.
        piece = tokenizer.decode([token_id], skip_special_tokens=False)
        if piece.isspace():
            piece = "␣"
        elif piece == "":
            piece = overrides.get(token_id, "∅")
        pieces.append(piece)
    return pieces


def count_tokens(text):
    """Tokenize ``text`` with the module-level T5 and CLIP tokenizers.

    Args:
        text: The prompt to tokenize. ``None`` is treated as an empty string.

    Returns:
        A 6-tuple:
        (T5 token count, T5 highlight pairs, T5 ids as ``str``,
         CLIP token count, CLIP highlight pairs, CLIP ids as ``str``).
        Each highlight pair is ``(token_text, "")``, i.e. an empty label.
    """
    text = text or ""

    # Plain Python lists for both tokenizers: requesting a PyTorch tensor
    # only to call .tolist() on it needlessly required torch at runtime.
    t5_ids = t5_tokenizer.encode(text, add_special_tokens=True)
    clip_ids = clip_tokenizer.encode(text, add_special_tokens=True)

    # T5 token id 3 decodes to an empty string on its own; display it as the
    # SentencePiece word-boundary marker "▁" instead of the generic "∅".
    t5_pieces = _decode_each(t5_tokenizer, t5_ids, empty_overrides={3: "▁"})
    clip_pieces = _decode_each(clip_tokenizer, clip_ids)

    return (
        # T5 outputs
        len(t5_ids),
        [(piece, "") for piece in t5_pieces],
        str(t5_ids),
        # CLIP outputs
        len(clip_ids),
        [(piece, "") for piece in clip_pieces],
        str(clip_ids),
    )

# Page layout: the prompt box on top, then two side-by-side result columns
# (T5 on the left, CLIP on the right) that refresh whenever the prompt changes.
with gr.Blocks(title="DiffusionTokenizer") as iface:
    gr.Markdown("# DiffusionTokenizer🔢")
    gr.Markdown("A lightning fast visualization of the tokens used in diffusion models. Use it to understand how your prompt is tokenized.")

    with gr.Row():
        text_input = gr.Textbox(
            label="Diffusion Prompt",
            placeholder="Enter your prompt here...",
        )

    with gr.Row():
        with gr.Column():  # left column: T5 results
            gr.Markdown("### T5 Tokenizer Results")
            t5_count = gr.Number(label="T5 Token Count")
            t5_highlights = gr.HighlightedText(
                label="T5 Tokens",
                show_legend=True,
            )
            t5_ids = gr.Textbox(label="T5 Token IDs", lines=2)

        with gr.Column():  # right column: CLIP results
            gr.Markdown("### CLIP Tokenizer Results")
            clip_count = gr.Number(label="CLIP Token Count")
            clip_highlights = gr.HighlightedText(
                label="CLIP Tokens",
                show_legend=True,
            )
            clip_ids = gr.Textbox(label="CLIP Token IDs", lines=2)

    # Order must match the 6-tuple returned by count_tokens.
    _token_outputs = [
        t5_count,
        t5_highlights,
        t5_ids,
        clip_count,
        clip_highlights,
        clip_ids,
    ]
    text_input.change(
        fn=count_tokens,
        inputs=[text_input],
        outputs=_token_outputs,
    )

# Start the web app: show_error=True reports handler exceptions in the
# browser, and ssr_mode=False turns off server-side rendering.
iface.launch(ssr_mode=False, show_error=True)