File size: 2,928 Bytes
b253e66
8443315
 
 
 
 
 
 
 
 
b253e66
8443315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b253e66
 
 
 
 
 
 
 
8443315
b253e66
 
 
 
 
8443315
b253e66
8443315
 
 
 
 
 
 
 
 
 
 
 
 
b253e66
 
 
 
8443315
b253e66
8443315
 
 
 
 
 
b253e66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from enum import Enum
from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding

root_dir = Path(__file__).resolve().parent
# Custom Streamlit component whose frontend is served from the local
# "highlighted_text/build" directory next to this script; it renders the
# tokens together with their per-context scores.
highlighted_text_component = components.declare_component(
    "highlighted_text",
    path=root_dir.joinpath("highlighted_text", "build"),
)

def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0) -> BatchEncoding:
    """Split every sequence of a tokenized batch into fixed-length sliding windows.

    For each example, a window starts at offsets ``0, stride, 2 * stride, ...``
    up to (but excluding) the example's last position. A window holds at most
    ``window_len`` items and is right-padded to exactly ``window_len`` when it
    runs past the end of its sequence: ``input_ids`` are padded with
    ``pad_id``, every other field with ``0``.

    Args:
        examples: tokenizer output whose fields are lists of per-example lists;
            all fields are assumed to be aligned with ``examples["input_ids"]``.
        window_len: length of every output window.
        stride: step between consecutive window start offsets.
        pad_id: fill value used when padding ``input_ids``.

    Returns:
        A ``BatchEncoding`` with the same keys, each mapping to a flat list of
        windows ordered by example, then by start offset.
    """
    windowed = {}
    for field, sequences in examples.items():
        fill_value = pad_id if field == "input_ids" else 0
        rows = []
        for example_idx, ids in enumerate(examples["input_ids"]):
            seq = sequences[example_idx]
            for start in range(0, len(ids) - 1, stride):
                window = seq[start : start + window_len]
                # Negative shortfall (window fully inside the sequence) yields no padding.
                shortfall = start + window_len - len(seq)
                rows.append(window + [fill_value] * shortfall)
        windowed[field] = rows
    return BatchEncoding(windowed)

BAD_CHAR = chr(0xfffd)  # U+FFFD REPLACEMENT CHARACTER, typically produced when decoding incomplete UTF-8

def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    """Decode token ids into one display string per id.

    With byte-level tokenizers a single character may presumably be split
    across several tokens, and decoding only part of it yields ``BAD_CHAR``.
    Ids are therefore accumulated until their joint decoding no longer *ends*
    with ``BAD_CHAR``; the decoded text is assigned to the id that completed
    it, and the ids before it get ``""``.

    Only a trailing ``BAD_CHAR`` is treated as "incomplete". A genuine U+FFFD
    inside the text is merged into the next token's string instead of causing
    every following token to come out empty. Ids still pending after the last
    one are flushed into the last slot rather than dropped.

    Args:
        tokenizer: object exposing ``decode(list_of_ids) -> str``.
        ids: sequence of token ids.
        strip_whitespace: strip leading/trailing whitespace from each piece.

    Returns:
        A list of strings with the same length as ``ids``.
    """
    pending = []
    result = []
    for idx in ids:
        pending.append(idx)
        decoded = tokenizer.decode(pending)
        if decoded.endswith(BAD_CHAR):
            # The pending ids end in the middle of a character: wait for more.
            result.append("")
            continue
        result.append(decoded.strip() if strip_whitespace else decoded)
        pending.clear()
    if pending:
        # Nothing completed the tail (e.g. the text itself ends with U+FFFD):
        # show what we have instead of silently dropping it.
        decoded = tokenizer.decode(pending)
        result[-1] = decoded.strip() if strip_whitespace else decoded
    return result

st.header("Context length probing")

# All inputs are grouped in a Streamlit form and submitted together via the
# "Submit" button at the bottom.
with st.form("form"):
    model_name = st.selectbox("Model", ["distilgpt2", "gpt2"])
    # Only one metric is offered; metric_name is not read anywhere else in this script.
    metric_name = st.selectbox("Metric", ["Cross entropy"])

    # NOTE(review): these loads run unconditionally on every script execution;
    # consider caching them (e.g. st.cache_resource, if the installed Streamlit
    # version provides it).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Longest context (in tokens) to probe; clamped to the input length further down.
    window_len = st.select_slider("Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512)
    text = st.text_area(
        "Input text",
        "The complex houses married and single soldiers and their families.",
    )

    st.form_submit_button("Submit")

inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]
if len(input_ids) < 2:
    # With fewer than two tokens there is no next token to predict: no windows
    # would be produced, and the model call and score normalization below
    # would fail on empty tensors.
    st.error("Please enter a text that is at least two tokens long.")
    st.stop()
# Never ask for windows longer than the text itself.
window_len = min(window_len, len(input_ids))
tokens = ids_to_readable_tokens(tokenizer, input_ids)

# One window starting at every position except the last, each padded to
# window_len: len(input_ids) - 1 windows in total.
inputs_sliding = get_windows_batched(
    inputs,
    window_len=window_len,
    pad_id=tokenizer.eos_token_id
)
with torch.inference_mode():
    # (num_windows, window_len, vocab); cast to float16, presumably to save memory.
    logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
    # -> (window_len, num_windows, vocab): row w holds position w of every window.
    logits = logits.permute(1, 0, 2)
    # Skew trick: pad each row with window_len NaNs, flatten, drop the last
    # window_len entries and re-view with rows one element shorter. This
    # shifts row w right by w, so afterwards logits[w, t] is the prediction
    # for token t + 1 from a context of w + 1 tokens ending at token t.
    # Cells with no such window (w > t) land on the NaN padding.
    logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
    logits = logits.view(-1, logits.shape[-1])[:-window_len]
    logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])

# Log-probability of each actual next token for every context length:
# (window_len, len(input_ids) - 1).
scores = logits.to(torch.float32).log_softmax(dim=-1)
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
# Change in log-probability from one more token of context, per target token:
# (len(input_ids) - 1, window_len - 1).
scores = scores.diff(dim=0).transpose(0, 1)
# Differences involving an unavailable context (NaN) count as no change.
scores = scores.nan_to_num()
# Scale each row into [-1, 1]; the epsilon avoids division by zero.
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
scores = scores.to(torch.float16)

highlighted_text_component(tokens=tokens, scores=scores.tolist())