from enum import Enum
from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding

root_dir = Path(__file__).resolve().parent

# Custom Streamlit component that renders tokens with per-token score highlighting.
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)


def get_windows_batched(
    examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0
) -> BatchEncoding:
    """Split each sequence into overlapping windows of length `window_len`, one window
    starting at every position up to the second-to-last token, padding windows that
    run past the end of the sequence."""
    return BatchEncoding({
        k: [
            t[i][j : j + window_len]
            + [pad_id if k == "input_ids" else 0] * (j + window_len - len(t[i]))
            for i in range(len(examples["input_ids"]))
            for j in range(0, len(examples["input_ids"][i]) - 1, stride)
        ]
        for k, t in examples.items()
    })


# Unicode replacement character, produced when a decoded byte sequence is incomplete.
BAD_CHAR = chr(0xFFFD)


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    """Decode token IDs one at a time, merging IDs whose partial decoding contains the
    replacement character into the following token so every output string is readable."""
    cur_ids = []
    result = []
    for idx in ids:
        cur_ids.append(idx)
        decoded = tokenizer.decode(cur_ids)
        if BAD_CHAR not in decoded:
            if strip_whitespace:
                decoded = decoded.strip()
            result.append(decoded)
            cur_ids.clear()
        else:
            result.append("")
    return result


st.header("Context length probing")

with st.form("form"):
    model_name = st.selectbox("Model", ["distilgpt2", "gpt2"])
    metric_name = st.selectbox("Metric", ["Cross entropy"])

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    window_len = st.select_slider(
        "Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512
    )

    text = st.text_area(
        "Input text",
        "The complex houses married and single soldiers and their families.",
    )

    st.form_submit_button("Submit")

inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]
window_len = min(window_len, len(input_ids))

tokens = ids_to_readable_tokens(tokenizer, input_ids)

# Run the model on all sliding windows in a single batch.
inputs_sliding = get_windows_batched(
    inputs, window_len=window_len, pad_id=tokenizer.eos_token_id
)
with torch.inference_mode():
    logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)

# Re-align the per-window logits so that dimension 0 indexes the context length and
# dimension 1 indexes the target position: padding each row with NaNs and reshaping
# shifts row p right by p positions, so column j collects the predictions for token
# j + 1 made with 1, 2, ..., window_len tokens of context.
logits = logits.permute(1, 0, 2)
logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
logits = logits.view(-1, logits.shape[-1])[:-window_len]
logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])

# Log-likelihood of each target token under each context length.
scores = logits.to(torch.float32).log_softmax(dim=-1)
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
# Difference along the context-length dimension gives the log-likelihood gain from
# each additional context token; normalize per target token for display.
scores = scores.diff(dim=0).transpose(0, 1)
scores = scores.nan_to_num()
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
scores = scores.to(torch.float16)

highlighted_text_component(tokens=tokens, scores=scores.tolist())