yenniejun committed on
Commit
535e678
1 Parent(s): 29daff8

clean up the figure, add data caching, add headers

Browse files
Files changed (1) hide show
  1. app.py +35 -18
app.py CHANGED
@@ -10,6 +10,11 @@ import numpy as np
10
  import plotly.figure_factory as ff
11
  import plotly.express as px
12
 
 
 
 
 
 
13
  tokenizer_names_to_test = [
14
  "openai/gpt4",
15
  "xlm-roberta-base", # old style
@@ -24,27 +29,30 @@ tokenizer_names_to_test = [
24
  ]
25
 
26
  with st.sidebar:
 
 
 
 
 
27
  with st.spinner('Loading dataset...'):
28
- val_data = pd.read_csv('MassiveDatasetValidationData.csv')
29
  st.success(f'Data loaded: {len(val_data)}')
30
 
31
  languages = st.multiselect(
32
  'Select languages',
33
  options=sorted(val_data.lang.unique()),
34
  default=['English', 'Spanish' ,'Chinese'],
35
- max_selections=5
36
  )
37
-
38
- # TODO multi-select tokenizers
39
- # TODO add openai to this options
40
- tokenizer_name = st.sidebar.selectbox('Tokenizers', options=tokenizer_names_to_test)
41
- st.write('You selected:', tokenizer_name)
42
 
43
  # with st.spinner('Loading tokenizer...'):
44
  # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
45
  # st.success(f'Tokenizer loaded: {tokenizer_name}')
46
 
47
- # # TODO - preload the tokenized versions ... much easier!
48
  # # TODO - add the metadata data as well??? later on maybe
49
  # with st.spinner('Calculating tokenization for data...'):
50
  # if tokenizer_name not in val_data.columns:
@@ -55,18 +63,27 @@ with st.container():
55
  if tokenizer_name in val_data.columns:
56
  subset_df = val_data[val_data.lang.isin(languages)]
57
  subset_data = [val_data[val_data.lang==_lang][tokenizer_name] for _lang in languages]
 
 
 
 
58
 
59
-
60
- fig = ff.create_distplot(subset_data, group_labels=languages, show_hist=False)
 
 
 
 
 
61
  st.plotly_chart(fig, use_container_width=True)
62
 
 
 
 
 
63
 
64
- # for _lang in languages:
65
- # subset = val_data[val_data.lang==_lang]
66
-
67
- # fig = ff.create_distplot(val_data, bin_size=.5,
68
- # curve_type='normal', # override default 'kde'
69
- # colors=colors)
70
-
71
-
72
 
 
10
  import plotly.figure_factory as ff
11
  import plotly.express as px
12
 
13
+ @st.cache_data
14
+ def load_data():
15
+ return pd.read_csv('MassiveDatasetValidationData.csv')
16
+
17
+ # TODO allow new tokenizers from HF
18
  tokenizer_names_to_test = [
19
  "openai/gpt4",
20
  "xlm-roberta-base", # old style
 
29
  ]
30
 
31
  with st.sidebar:
32
+ st.subheader('Model')
33
+ # TODO multi-select tokenizers
34
+ tokenizer_name = st.sidebar.selectbox('Select tokenizer', options=tokenizer_names_to_test)
35
+
36
+ st.subheader('Data')
37
  with st.spinner('Loading dataset...'):
38
+ val_data = load_data()
39
  st.success(f'Data loaded: {len(val_data)}')
40
 
41
  languages = st.multiselect(
42
  'Select languages',
43
  options=sorted(val_data.lang.unique()),
44
  default=['English', 'Spanish' ,'Chinese'],
45
+ max_selections=6
46
  )
47
+
48
+ st.subheader('Figure')
49
+ show_hist = st.checkbox('Show histogram', value=False)
50
+ # dist_marginal = st.radio('Select distribution', options=['box', 'violin', 'rug'], horizontal=True)
 
51
 
52
  # with st.spinner('Loading tokenizer...'):
53
  # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
54
  # st.success(f'Tokenizer loaded: {tokenizer_name}')
55
 
 
56
  # # TODO - add the metadata data as well??? later on maybe
57
  # with st.spinner('Calculating tokenization for data...'):
58
  # if tokenizer_name not in val_data.columns:
 
63
  if tokenizer_name in val_data.columns:
64
  subset_df = val_data[val_data.lang.isin(languages)]
65
  subset_data = [val_data[val_data.lang==_lang][tokenizer_name] for _lang in languages]
66
+
67
+ st.header('Tokenization in different languages')
68
+ st.divider()
69
+ fig = ff.create_distplot(subset_data, group_labels=languages, show_hist=show_hist)
70
 
71
+ fig.update_layout(
72
+ title=dict(text=tokenizer_name, font=dict(size=25), automargin=True, yref='paper', ),
73
+ # title=tokenizer_name,
74
+ xaxis_title="Number of Tokens",
75
+ yaxis_title="Density",
76
+ # title_font_family='"Source Sans Pro", sans-serif'
77
+ )
78
  st.plotly_chart(fig, use_container_width=True)
79
 
80
+ st.subheader('Median Token Length')
81
+ metric_cols = st.columns(len(languages))
82
+ for i, _lang in enumerate(languages):
83
+ metric_cols[i].metric(_lang, int(np.median(subset_df[subset_df.lang==_lang][tokenizer_name])))
84
 
85
+ if tokenizer_name not in ['openai/gpt4']:
86
+ url = f'https://huggingface.co/{tokenizer_name}'
87
+ link = f'[Find on the HuggingFace hub]({url})'
88
+ st.markdown(link, unsafe_allow_html=True)
 
 
 
 
89