File size: 2,005 Bytes
f38dbc0
b60285f
 
 
 
 
 
 
 
f38dbc0
 
b60285f
 
 
 
f38dbc0
b60285f
 
 
 
 
f38dbc0
b60285f
b83ecd0
b60285f
 
 
 
5008e92
 
 
 
 
 
 
 
 
 
 
 
b60285f
 
 
f38dbc0
 
 
b60285f
 
 
 
 
 
 
 
 
 
f38dbc0
 
 
 
b60285f
 
 
 
 
 
 
 
f38dbc0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import io

import pandas as pd
import streamlit as st

from transformers import AutoTokenizer

from tapas_visualizer import TapasVisualizer

# Configure the page title, icon and wide layout for the app.
st.set_page_config(
    page_title="Tapas Tokenizer",
    page_icon="‍🍽️",
    layout="wide",
)


def set_file_input():
    """Mark the file uploader as the input source in the session state."""
    st.session_state["input_stream"] = "file"


def set_text_input():
    """Mark the text area as the input source in the session state."""
    st.session_state["input_stream"] = "text"


def _read_table(text, uploaded_file, input_stream):
    """Parse the user's table from whichever input should take precedence.

    The input named by ``input_stream`` wins when it actually holds data;
    otherwise the other input is used as a fallback, so a stale marker
    (e.g. ``"file"`` after the upload was removed) does not cause a crash.

    Args:
        text: Contents of the CSV text area; may be empty.
        uploaded_file: Binary file-like object from the uploader, or ``None``.
        input_stream: ``"file"`` or ``"text"`` (the widget that changed
            last), or ``None`` when neither callback has fired yet.

    Returns:
        The parsed ``pd.DataFrame``, or ``None`` when neither input holds
        any data.

    Raises:
        pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError:
            Propagated from ``pd.read_csv`` when the content is not valid CSV.
    """
    has_text = bool(text and text.strip())
    has_file = uploaded_file is not None

    if has_file and (input_stream == "file" or not has_text):
        # Rewind in case the buffer was already consumed by an earlier read.
        uploaded_file.seek(0)
        return pd.read_csv(uploaded_file)
    if has_text:
        return pd.read_csv(io.StringIO(text), sep=",")
    return None


def main():
    """Render the TAPAS tokenization page.

    Lets the user pick a tokenizer, provide a table as CSV text or as an
    uploaded file, and renders the visualizer output for that table.
    """
    models = ["google/tapas-base", "deepset/tapas-large-nq-hn-reader"]

    # The model name must be a parameter: the cache entry is looked up by the
    # call arguments, so a zero-argument loader that closes over the selection
    # would keep returning the first tokenizer it ever loaded.
    @st.cache_resource()
    def load_tokenizer(model_name):
        return AutoTokenizer.from_pretrained(model_name)

    st.markdown(
        """
        ## TAPAS Tokenization Visualization

        [TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas) models work on Tables.
        The tool below is to help visualize how the table is tokenized and give total (+ row-wise)
        token counts.

        Implementation adapted from `tokenizers.tools.EncodingVisualizer`.
        """
    )

    col1, col2 = st.columns([1, 2])
    with col1:
        selected_model = st.selectbox("Select a tokenizer", models, key=1)
        # Streamlit discourages empty labels; supply one and hide it instead.
        text = st.text_area(
            label="Table to tokenize (CSV)",
            label_visibility="collapsed",
            placeholder="Table to tokenize; csv",
            on_change=set_text_input,
        )
        uploaded_file = st.file_uploader("(Or) Choose a file", on_change=set_file_input)
        button_clicked = st.button("Tokenize")

    tokenizer = load_tokenizer(selected_model)
    visualizer = TapasVisualizer(tokenizer)

    with col2:
        if text or uploaded_file or button_clicked:
            try:
                df = _read_table(
                    text, uploaded_file, st.session_state.get("input_stream")
                )
            except (
                pd.errors.EmptyDataError,
                pd.errors.ParserError,
                UnicodeDecodeError,
            ) as err:
                # Show a readable message instead of a raw traceback.
                st.error(f"Could not parse the table as CSV: {err}")
            else:
                if df is None:
                    # "Tokenize" was clicked with no text and no file.
                    st.info("Enter a CSV table or upload a file to tokenize.")
                else:
                    # The visualizer is given every cell as a string.
                    st.components.v1.html(visualizer(df.astype(str)), height=1500)


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()