edugp committed
Commit 3c30fa3 · 1 Parent(s): 7b62017

Sync with data tooling repo: use edugp/kenlm models, update the viz to color by quantiles, and add an ad-hoc viz for the registry dataset

README.md CHANGED
@@ -11,7 +11,11 @@ pinned: false
 # Installation:
 Requires Python >= 3.7 and < 3.10
 ```
-pip install -r requirements.txt
+pip install .
+```
+Or with [poetry](https://python-poetry.org/)
+```
+poetry install
 ```
 
 # Web App:
app.py CHANGED
@@ -3,16 +3,28 @@ from functools import partial
 
 import streamlit as st
 from embedding_lenses.data import uploaded_file_to_dataframe
-from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings,
-                                                       get_umap_embeddings)
+from embedding_lenses.dimensionality_reduction import (
+    get_tsne_embeddings,
+    get_umap_embeddings,
+)
 from embedding_lenses.embedding import load_model
 
-from perplexity_lenses.data import (documents_df_to_sentences_df,
-                                    hub_dataset_to_dataframe)
-from perplexity_lenses.engine import (DIMENSIONALITY_REDUCTION_ALGORITHMS,
-                                      DOCUMENT_TYPES, EMBEDDING_MODELS,
-                                      LANGUAGES, SEED, generate_plot)
+from perplexity_lenses import REGISTRY_DATASET
+from perplexity_lenses.data import (
+    documents_df_to_sentences_df,
+    hub_dataset_to_dataframe,
+)
+from perplexity_lenses.engine import (
+    DIMENSIONALITY_REDUCTION_ALGORITHMS,
+    DOCUMENT_TYPES,
+    EMBEDDING_MODELS,
+    LANGUAGES,
+    PERPLEXITY_MODELS,
+    SEED,
+    generate_plot,
+)
 from perplexity_lenses.perplexity import KenlmModel
+from perplexity_lenses.visualization import draw_histogram
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -43,12 +55,36 @@ with col6:
     doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
 with col7:
     sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
+perplexity_model = st.selectbox(
+    "Dataset on which the perplexity model was trained on", PERPLEXITY_MODELS, 0
+).lower()
 
 dimensionality_reduction = st.selectbox(
     "Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0
 )
 model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
 
+advanced_options = st.checkbox(
+    "Advanced options (do not modify if using default KenLM models).", value=False
+)
+lower_case = True
+remove_accents = True
+normalize_numbers = True
+punctuation = 1
+if advanced_options:
+    lower_case = st.checkbox(
+        "Lower case text for KenLM preprocessing (from cc_net)", value=False
+    )
+    remove_accents = st.checkbox(
+        "Remove accents for KenLM preprocessing (from cc_net)", value=False
+    )
+    normalize_numbers = st.checkbox(
+        "Replace numbers with zeros KenLM preprocessing (from cc_net)", value=True
+    )
+    punctuation = st.number_input(
+        "Punctuation mode to use from cc_net KenLM preprocessing", 1, 2, 1
+    )
+
 with st.spinner(text="Loading embedding model..."):
     model = load_model(model_name)
 dimensionality_reduction_function = (
@@ -58,7 +94,14 @@ dimensionality_reduction_function = (
 )
 
 with st.spinner(text="Loading KenLM model..."):
-    kenlm_model = KenlmModel.from_pretrained(language)
+    kenlm_model = KenlmModel.from_pretrained(
+        perplexity_model,
+        language,
+        lower_case,
+        remove_accents,
+        normalize_numbers,
+        punctuation,
+    )
 
 if uploaded_file or hub_dataset:
     with st.spinner("Loading dataset..."):
@@ -84,7 +127,7 @@ if uploaded_file or hub_dataset:
     logger.info(
         f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
     )
-    plot = generate_plot(
+    plot, plot_registry = generate_plot(
         df,
         text_column,
         "perplexity",
@@ -93,7 +136,12 @@ if uploaded_file or hub_dataset:
         model,
         seed=SEED,
         context_logger=st.spinner,
+        hub_dataset=hub_dataset,
     )
-    logger.info("Displaying plot")
+    logger.info("Displaying plots")
     st.bokeh_chart(plot)
+    if hub_dataset == REGISTRY_DATASET:
+        st.bokeh_chart(plot_registry)
+    fig = draw_histogram(df["perplexity"].values)
+    st.pyplot(fig)
     logger.info("Done")
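
Note: with "Advanced options" unchecked, the app's model-loading path reduces to the call below (a minimal sketch, not part of the diff; it assumes the edugp/kenlm Hub repo is reachable and mirrors the defaults set above).

```python
# Sketch of app.py's default path: the "Wikipedia"/"OSCAR" selectbox value
# is lowercased into the model folder name inside the edugp/kenlm repo.
from perplexity_lenses.perplexity import KenlmModel

perplexity_model = "Wikipedia".lower()  # -> "wikipedia"
kenlm_model = KenlmModel.from_pretrained(
    perplexity_model,
    "en",  # any language code from LANGUAGES
    True,  # lower_case (default when advanced options are off)
    True,  # remove_accents
    True,  # normalize_numbers
    1,     # punctuation mode
)
```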
cli.py CHANGED
@@ -6,17 +6,28 @@ import pandas as pd
 import typer
 from bokeh.plotting import output_file as bokeh_output_file
 from bokeh.plotting import save
-from embedding_lenses.data import uploaded_file_to_dataframe
-from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings,
-                                                       get_umap_embeddings)
+from embedding_lenses.dimensionality_reduction import (
+    get_tsne_embeddings,
+    get_umap_embeddings,
+)
 from embedding_lenses.embedding import load_model
 
-from perplexity_lenses.data import (documents_df_to_sentences_df,
-                                    hub_dataset_to_dataframe)
-from perplexity_lenses.engine import (DIMENSIONALITY_REDUCTION_ALGORITHMS,
-                                      DOCUMENT_TYPES, EMBEDDING_MODELS,
-                                      LANGUAGES, SEED, generate_plot)
+from perplexity_lenses import REGISTRY_DATASET
+from perplexity_lenses.data import (
+    documents_df_to_sentences_df,
+    hub_dataset_to_dataframe,
+)
+from perplexity_lenses.engine import (
+    DIMENSIONALITY_REDUCTION_ALGORITHMS,
+    DOCUMENT_TYPES,
+    EMBEDDING_MODELS,
+    LANGUAGES,
+    PERPLEXITY_MODELS,
+    SEED,
+    generate_plot,
+)
 from perplexity_lenses.perplexity import KenlmModel
+from perplexity_lenses.visualization import draw_histogram
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -46,6 +57,10 @@ def main(
         help=f"Whether to embed at the sentence or document level. Options: {DOCUMENT_TYPES}.",
     ),
     sample: int = typer.Option(1000, help="Maximum number of examples to use."),
+    perplexity_model: str = typer.Option(
+        "wikipedia",
+        help=f"Dataset on which the perplexity model was trained on. Options: {PERPLEXITY_MODELS}",
+    ),
     dimensionality_reduction: str = typer.Option(
         DIMENSIONALITY_REDUCTION_ALGORITHMS[0],
         help=f"Whether to use UMAP or t-SNE for dimensionality reduction. Options: {DIMENSIONALITY_REDUCTION_ALGORITHMS}.",
@@ -55,7 +70,7 @@ def main(
         help=f"The sentence embedding model to use. Options: {EMBEDDING_MODELS}",
     ),
     output_file: str = typer.Option(
-        "perplexity.html", help="The name of the output visualization HTML file."
+        "perplexity", help="The name of the output visualization files."
     ),
 ):
     """
@@ -69,7 +84,14 @@ def main(
         else partial(get_tsne_embeddings, random_state=SEED)
     )
     logger.info("Loading KenLM model...")
-    kenlm_model = KenlmModel.from_pretrained(language)
+    kenlm_model = KenlmModel.from_pretrained(
+        perplexity_model.lower(),
+        language,
+        lower_case=True,
+        remove_accents=True,
+        normalize_numbers=True,
+        punctuation=1,
+    )
     logger.info("Loading dataset...")
     if dataset.endswith(".csv") or dataset.endswith(".tsv"):
         df = pd.read_csv(dataset, sep="\t" if dataset.endswith(".tsv") else ",")
@@ -92,7 +114,7 @@ def main(
     logger.info(
         f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
    )
-    plot = generate_plot(
+    plot, plot_registry = generate_plot(
         df,
         text_column,
         "perplexity",
@@ -100,10 +122,16 @@ def main(
         dimensionality_reduction_function,
         model,
         seed=SEED,
+        hub_dataset=dataset,
     )
-    logger.info("Saving plot")
-    bokeh_output_file(output_file)
+    logger.info("Saving plots")
+    bokeh_output_file(f"{output_file}.html")
     save(plot)
+    if dataset == REGISTRY_DATASET:
+        bokeh_output_file(f"{output_file}_registry.html")
+        save(plot_registry)
+    fig = draw_histogram(df["perplexity"].values)
+    fig.savefig(f"{output_file}_histogram.png")
     logger.info("Done")
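
Note: `--output-file` is now an extension-less base name, and the CLI writes up to three artifacts from it. A minimal sketch of the naming convention (illustrative only; the empty `figure()` stands in for the real `generate_plot` output):

```python
from bokeh.plotting import figure, output_file, save

output_base = "perplexity"          # value of the --output-file option
plot = figure()                     # stand-in for the scatter plot from generate_plot
plot.circle([0, 1], [0, 1])
output_file(f"{output_base}.html")  # main scatter plot
save(plot)
# For the registry dataset an extra f"{output_base}_registry.html" is saved,
# and the perplexity histogram goes to f"{output_base}_histogram.png".
```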
perplexity_lenses/__init__.py CHANGED
@@ -1 +1,2 @@
 __version__ = "0.1.0"
+REGISTRY_DATASET = "mhtoin/register_oscar"
perplexity_lenses/data.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
 from datasets import load_dataset
 from tqdm import tqdm
 
+from perplexity_lenses import REGISTRY_DATASET
 from perplexity_lenses.perplexity import KenlmModel
 
 
@@ -21,13 +22,22 @@ def hub_dataset_to_dataframe(
     load_dataset_fn = partial(load_dataset, path=path)
     if name:
         load_dataset_fn = partial(load_dataset_fn, name=name)
+    # Special case for the registry dataset
+    if path == REGISTRY_DATASET:
+        load_dataset_fn = partial(load_dataset_fn, data_files=f"{name}/*")
     if split:
         load_dataset_fn = partial(load_dataset_fn, split=split)
     dataset = load_dataset_fn(streaming=True).shuffle(buffer_size=10000, seed=seed)
     if doc_type.lower() == "sentence":
         dataset = dataset.map(
             lambda x: [
-                {text_column: sentence, "perplexity": model.get_perplexity(sentence)}
+                {
+                    text_column: sentence,
+                    "perplexity": model.get_perplexity(sentence),
+                    "label": x.get("labels", [])[0]
+                    if len(x.get("labels", [])) > 0
+                    else "NONE",  # Special case for registry dataset
+                }
                 for sentence in x[text_column].split("\n")
             ]
         )
@@ -36,6 +46,9 @@ def hub_dataset_to_dataframe(
             lambda x: {
                 text_column: x[text_column],
                 "perplexity": model.get_perplexity(x[text_column]),
+                "label": x.get("labels", [])[0]
+                if len(x.get("labels", [])) > 0
+                else "NONE",  # Special case for registry dataset
             }
         )
     instances = []
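
Note: the new `label` field falls back to "NONE" whenever a row has no `labels` list, which is what makes the same mapping safe for both the registry dataset and ordinary hub datasets. A standalone illustration of the fallback expression (the row values are hypothetical):

```python
# Same fallback expression as in hub_dataset_to_dataframe, applied to two rows.
rows = [
    {"text": "some text", "labels": ["NA"]},  # registry-style row with a label
    {"text": "some text"},                    # ordinary row without labels
]
for x in rows:
    label = x.get("labels", [])[0] if len(x.get("labels", [])) > 0 else "NONE"
    print(label)  # -> "NA", then "NONE"
```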
perplexity_lenses/engine.py CHANGED
@@ -1,15 +1,17 @@
 import logging
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Tuple, Union
 
 import pandas as pd
 import streamlit as st
+from bokeh.palettes import Turbo256
 from bokeh.plotting import Figure
 from embedding_lenses.embedding import embed_text
 from embedding_lenses.utils import encode_labels
+from embedding_lenses.visualization import draw_interactive_scatter_plot
 from sentence_transformers import SentenceTransformer
 
-from perplexity_lenses.visualization import draw_interactive_scatter_plot
+from perplexity_lenses import REGISTRY_DATASET
 
 logger = logging.getLogger(__name__)
 EMBEDDING_MODELS = [
@@ -71,6 +73,7 @@ LANGUAGES = [
     "uk",
     "zh",
 ]
+PERPLEXITY_MODELS = ["Wikipedia", "OSCAR"]
 
 
 class ContextLogger:
@@ -94,7 +97,8 @@ def generate_plot(
     model: SentenceTransformer,
     seed: int = 0,
     context_logger: Union[st.spinner, ContextLogger] = ContextLogger,
-) -> Figure:
+    hub_dataset: str = "",
+) -> Tuple[Figure, Optional[Figure]]:
     if text_column not in df.columns:
         raise ValueError(
             f"The specified column name doesn't exist. Columns available: {df.columns.values}"
@@ -111,13 +115,32 @@ def generate_plot(
     with context_logger("Reducing dimensionality..."):
         embeddings_2d = dimensionality_reduction_function(embeddings)
     logger.info("Generating figure")
+    hover_data = {
+        text_column: df[text_column].values,
+        label_column: encoded_labels.values,
+    }
+    # Round perplexity values
+    values = df[label_column].values.round().astype(int)
     plot = draw_interactive_scatter_plot(
-        df[text_column].values,
+        hover_data,
         embeddings_2d[:, 0],
         embeddings_2d[:, 1],
-        encoded_labels.values,
-        df[label_column].values,
-        text_column,
-        label_column,
+        values,
     )
-    return plot
+    # Special case for the registry dataset
+    plot_registry = None
+    if hub_dataset == REGISTRY_DATASET:
+        encoded_labels = encode_labels(df["label"])
+        hover_data = {
+            text_column: df[text_column].values,
+            "label": df["label"].values,
+            label_column: df[label_column].values,
+        }
+        plot_registry = draw_interactive_scatter_plot(
+            hover_data,
+            embeddings_2d[:, 0],
+            embeddings_2d[:, 1],
+            encoded_labels.values,
+            palette=Turbo256,
+        )
+    return plot, plot_registry
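
Note: `generate_plot` now rounds the perplexity column to integers before handing it to the scatter plot from embedding_lenses 0.8.0 (which does the quantile-based coloring mentioned in the commit message), and returns a pair whose second element is populated only for the registry dataset. A small sketch of both points, on synthetic data:

```python
import pandas as pd

# Rounding step, as in generate_plot: colors are driven by integer perplexities.
df = pd.DataFrame({"perplexity": [123.4, 56.78, 9000.1]})
values = df["perplexity"].values.round().astype(int)
print(values)  # [ 123   57 9000]

# Callers unpack both figures; the second is None outside the registry dataset:
# plot, plot_registry = generate_plot(df, text_column, "perplexity", ...,
#                                     hub_dataset=hub_dataset)
# if plot_registry is not None:
#     st.bokeh_chart(plot_registry)
```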
perplexity_lenses/perplexity.py CHANGED
@@ -1,11 +1,13 @@
 import os
 import re
 import unicodedata
-import urllib.request
 from typing import Dict
 
 import kenlm
 import sentencepiece
+from huggingface_hub import cached_download, hf_hub_url
+
+KENLM_MODEL_REPO = "edugp/kenlm"
 
 
 class SentencePiece:
@@ -64,30 +66,65 @@ class KenlmModel:
     non_printing_chars_re = re.compile(
         f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
     )
+    kenlm_model_dir = None
+    sentence_piece_model_dir = None
 
-    def __init__(self, language):
-        download_kenlm_model(language)
+    def __init__(
+        self,
+        model_dataset: str,
+        language: str,
+        lower_case: bool = False,
+        remove_accents: bool = False,
+        normalize_numbers: bool = True,
+        punctuation: int = 1,
+    ):
+        self.download_kenlm_model(model_dataset, language)
         try:
-            self.model = kenlm.Model(f"{language}.arpa.bin")
-            self.tokenizer = SentencePiece(f"{language}.sp.model")
+            self.model = kenlm.Model(self.kenlm_model_dir)
+            self.tokenizer = SentencePiece(self.sentence_piece_model_dir)
         except OSError:
-            os.remove(f"{language}.arpa.bin")
-            if os.path.exists(f"{language}.sp.model"):
-                os.remove(f"{language}.sp.model")
+            os.remove(self.kenlm_model_dir)
+            if os.path.exists(self.sentence_piece_model_dir):
+                os.remove(self.sentence_piece_model_dir)
             raise OSError(
                 "File was corrupt and should have been removed. Please, retry."
             )
+        self.accent = remove_accents
+        self.case = lower_case
+        self.numbers = normalize_numbers
+        self.punct = punctuation
 
     @classmethod
-    def from_pretrained(cls, language: str):
-        return cls(language)
+    def from_pretrained(
+        cls,
+        model_dataset: str,
+        language: str,
+        lower_case: bool,
+        remove_accents: bool,
+        normalize_numbers: bool,
+        punctuation: int,
+    ):
+        return cls(
+            model_dataset,
+            language,
+            lower_case,
+            remove_accents,
+            normalize_numbers,
+            punctuation,
+        )
 
     def pp(self, log_score, length):
         return 10.0 ** (-log_score / length)
 
     def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
         if normalize_cc_net:
-            doc = self.normalize(doc)
+            doc = self.normalize(
+                doc,
+                accent=self.accent,
+                case=self.case,
+                numbers=self.numbers,
+                punct=self.punct,
+            )
         # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
         doc = self.tokenizer.do(doc)
         doc_log_score, doc_length = 0, 0
@@ -140,16 +177,12 @@ class KenlmModel:
     def remove_non_printing_char(self, text: str) -> str:
         return self.non_printing_chars_re.sub("", text)
 
-
-def download_kenlm_model(language: str):
-    root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
-    bin_name = f"{language}.arpa.bin"
-    model_name = f"{language}.sp.model"
-    bin_url = f"{root_url}/{bin_name}"
-    model_url = f"{root_url}/{model_name}"
-
-    if not os.path.isfile(bin_name):
-        urllib.request.urlretrieve(bin_url, bin_name)
-
-    if not os.path.isfile(model_name):
-        urllib.request.urlretrieve(model_url, model_name)
+    def download_kenlm_model(self, model_dataset: str, language: str):
+        kenlm_model_url = hf_hub_url(
+            KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.arpa.bin"
+        )
+        self.kenlm_model_dir = cached_download(kenlm_model_url)
+        sentence_piece_model_url = hf_hub_url(
+            KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.sp.model"
+        )
+        self.sentence_piece_model_dir = cached_download(sentence_piece_model_url)
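
Note: models now come from the edugp/kenlm Hub repo instead of dl.fbaipublicfiles.com; `hf_hub_url` and `cached_download` (both present in the pinned huggingface-hub 0.0.19) cache the `.arpa.bin` and `.sp.model` files locally. A minimal usage sketch (needs network access on first call; the sample sentence is arbitrary):

```python
from perplexity_lenses.perplexity import KenlmModel

# Downloads wikipedia/en.arpa.bin and wikipedia/en.sp.model from edugp/kenlm
# on first use; later constructions hit the local cache.
model = KenlmModel.from_pretrained(
    "wikipedia",          # model_dataset: "wikipedia" or "oscar"
    "en",                 # language code
    True, True, True, 1,  # lower_case, remove_accents, normalize_numbers, punctuation
)
print(model.get_perplexity("She was a shy and quiet child."))  # lower = more fluent
```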
perplexity_lenses/visualization.py CHANGED
@@ -1,57 +1,18 @@
+import matplotlib.figure
+import matplotlib.pyplot as plt
 import numpy as np
-from bokeh.models import ColumnDataSource, HoverTool
-from bokeh.palettes import Cividis256 as Pallete
-from bokeh.plotting import Figure, figure
-from bokeh.transform import factor_cmap
 
 
-def draw_interactive_scatter_plot(
-    texts: np.ndarray,
-    xs: np.ndarray,
-    ys: np.ndarray,
+def draw_histogram(
     values: np.ndarray,
-    labels: np.ndarray,
-    text_column: str,
-    label_column: str,
-) -> Figure:
-    # Smooth down values for coloring, by taking the entropy = log10(perplexity) and multiply it by 10000
-    values = ((np.log10(values)) * 10000).round().astype(int)
-    # Normalize values to range between 0-255, to assign a color for each value
-    max_value = values.max()
-    min_value = values.min()
-    if max_value - min_value == 0:
-        values_color = np.ones(len(values))
-    else:
-        values_color = (
-            ((values - min_value) / (max_value - min_value) * 255).round().astype(int)
-        )
-    values_color_sorted = sorted(values_color)
-
-    values_list = values.astype(str).tolist()
-    values_sorted = sorted(values_list)
-    labels_list = labels.astype(str).tolist()
-
-    source = ColumnDataSource(
-        data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)
-    )
-    hover = HoverTool(
-        tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")]
-    )
-    p = figure(plot_width=800, plot_height=800, tools=[hover])
-    p.circle(
-        "x",
-        "y",
-        size=10,
-        source=source,
-        fill_color=factor_cmap(
-            "label",
-            palette=[Pallete[id_] for id_ in values_color_sorted],
-            factors=values_sorted,
-        ),
-    )
-
-    p.axis.visible = False
-    p.xgrid.grid_line_color = None
-    p.ygrid.grid_line_color = None
-    p.toolbar.logo = None
-    return p
+    cutoff_x_axis: float = 2000.0,
+    title: str = "Perplexity histogram",
+    xlabel: str = "Perplexity",
+) -> matplotlib.figure.Figure:
+    hist_values = values[values < cutoff_x_axis]
+    fig, ax = plt.subplots(figsize=(12, 9))
+    ax.hist(hist_values, bins=50)
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel("Counts")
+    return fig
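
Note: `draw_histogram` replaces the old Bokeh scatter helper (now imported from embedding_lenses) and returns a Matplotlib figure, so it can be rendered with `st.pyplot` or saved to disk. A quick sketch with synthetic perplexities:

```python
import numpy as np
from perplexity_lenses.visualization import draw_histogram

values = np.random.lognormal(mean=5.0, sigma=1.0, size=1000)  # fake perplexities
fig = draw_histogram(values)  # drops values >= cutoff_x_axis (2000) before binning
fig.savefig("perplexity_histogram.png")
```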
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 bokeh==2.2.2
-https://files.pythonhosted.org/packages/2f/58/e00d2495b54f4ba97ca31a11aa7e636f80183ccf9b616f7eaa5518d050bb/embedding_lenses-0.5.0-py3-none-any.whl
+https://files.pythonhosted.org/packages/52/b6/798b1ac755c19cc5a91e5f5068544ed9deda74cb9339e4bb5e2b5f3ff3ec/embedding_lenses-0.8.0-py3-none-any.whl
 https://github.com/kpu/kenlm/archive/master.zip
 huggingface-hub==0.0.19
 numpy==1.20.0