Spaces:
Runtime error
Runtime error
Sync with data tooling repo, using edugp/kenlm models, updating viz to use quantiles for coloring and ad-hoc viz for the registry dataset
Browse files- README.md +5 -1
- app.py +58 -10
- cli.py +41 -13
- perplexity_lenses/__init__.py +1 -0
- perplexity_lenses/data.py +14 -1
- perplexity_lenses/engine.py +32 -9
- perplexity_lenses/perplexity.py +57 -24
- perplexity_lenses/visualization.py +14 -53
- requirements.txt +1 -1
README.md
CHANGED
@@ -11,7 +11,11 @@ pinned: false
|
|
11 |
# Installation:
|
12 |
Requires Python >= 3.7 and < 3.10
|
13 |
```
|
14 |
-
pip install
|
|
|
|
|
|
|
|
|
15 |
```
|
16 |
|
17 |
# Web App:
|
|
|
11 |
# Installation:
|
12 |
Requires Python >= 3.7 and < 3.10
|
13 |
```
|
14 |
+
pip install .
|
15 |
+
```
|
16 |
+
Or with [poetry](https://python-poetry.org/)
|
17 |
+
```
|
18 |
+
poetry install
|
19 |
```
|
20 |
|
21 |
# Web App:
|
app.py
CHANGED
@@ -3,16 +3,28 @@ from functools import partial
|
|
3 |
|
4 |
import streamlit as st
|
5 |
from embedding_lenses.data import uploaded_file_to_dataframe
|
6 |
-
from embedding_lenses.dimensionality_reduction import (
|
7 |
-
|
|
|
|
|
8 |
from embedding_lenses.embedding import load_model
|
9 |
|
10 |
-
from perplexity_lenses
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
from perplexity_lenses.perplexity import KenlmModel
|
|
|
16 |
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
@@ -43,12 +55,36 @@ with col6:
|
|
43 |
doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
|
44 |
with col7:
|
45 |
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
|
|
|
|
|
|
|
46 |
|
47 |
dimensionality_reduction = st.selectbox(
|
48 |
"Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0
|
49 |
)
|
50 |
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
with st.spinner(text="Loading embedding model..."):
|
53 |
model = load_model(model_name)
|
54 |
dimensionality_reduction_function = (
|
@@ -58,7 +94,14 @@ dimensionality_reduction_function = (
|
|
58 |
)
|
59 |
|
60 |
with st.spinner(text="Loading KenLM model..."):
|
61 |
-
kenlm_model = KenlmModel.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if uploaded_file or hub_dataset:
|
64 |
with st.spinner("Loading dataset..."):
|
@@ -84,7 +127,7 @@ if uploaded_file or hub_dataset:
|
|
84 |
logger.info(
|
85 |
f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
|
86 |
)
|
87 |
-
plot = generate_plot(
|
88 |
df,
|
89 |
text_column,
|
90 |
"perplexity",
|
@@ -93,7 +136,12 @@ if uploaded_file or hub_dataset:
|
|
93 |
model,
|
94 |
seed=SEED,
|
95 |
context_logger=st.spinner,
|
|
|
96 |
)
|
97 |
-
logger.info("Displaying
|
98 |
st.bokeh_chart(plot)
|
|
|
|
|
|
|
|
|
99 |
logger.info("Done")
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
from embedding_lenses.data import uploaded_file_to_dataframe
|
6 |
+
from embedding_lenses.dimensionality_reduction import (
|
7 |
+
get_tsne_embeddings,
|
8 |
+
get_umap_embeddings,
|
9 |
+
)
|
10 |
from embedding_lenses.embedding import load_model
|
11 |
|
12 |
+
from perplexity_lenses import REGISTRY_DATASET
|
13 |
+
from perplexity_lenses.data import (
|
14 |
+
documents_df_to_sentences_df,
|
15 |
+
hub_dataset_to_dataframe,
|
16 |
+
)
|
17 |
+
from perplexity_lenses.engine import (
|
18 |
+
DIMENSIONALITY_REDUCTION_ALGORITHMS,
|
19 |
+
DOCUMENT_TYPES,
|
20 |
+
EMBEDDING_MODELS,
|
21 |
+
LANGUAGES,
|
22 |
+
PERPLEXITY_MODELS,
|
23 |
+
SEED,
|
24 |
+
generate_plot,
|
25 |
+
)
|
26 |
from perplexity_lenses.perplexity import KenlmModel
|
27 |
+
from perplexity_lenses.visualization import draw_histogram
|
28 |
|
29 |
logging.basicConfig(level=logging.INFO)
|
30 |
logger = logging.getLogger(__name__)
|
|
|
55 |
doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
|
56 |
with col7:
|
57 |
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
|
58 |
+
perplexity_model = st.selectbox(
|
59 |
+
"Dataset on which the perplexity model was trained on", PERPLEXITY_MODELS, 0
|
60 |
+
).lower()
|
61 |
|
62 |
dimensionality_reduction = st.selectbox(
|
63 |
"Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0
|
64 |
)
|
65 |
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
|
66 |
|
67 |
+
advanced_options = st.checkbox(
|
68 |
+
"Advanced options (do not modify if using default KenLM models).", value=False
|
69 |
+
)
|
70 |
+
lower_case = True
|
71 |
+
remove_accents = True
|
72 |
+
normalize_numbers = True
|
73 |
+
punctuation = 1
|
74 |
+
if advanced_options:
|
75 |
+
lower_case = st.checkbox(
|
76 |
+
"Lower case text for KenLM preprocessing (from cc_net)", value=False
|
77 |
+
)
|
78 |
+
remove_accents = st.checkbox(
|
79 |
+
"Remove accents for KenLM preprocessing (from cc_net)", value=False
|
80 |
+
)
|
81 |
+
normalize_numbers = st.checkbox(
|
82 |
+
"Replace numbers with zeros KenLM preprocessing (from cc_net)", value=True
|
83 |
+
)
|
84 |
+
punctuation = st.number_input(
|
85 |
+
"Punctuation mode to use from cc_net KenLM preprocessing", 1, 2, 1
|
86 |
+
)
|
87 |
+
|
88 |
with st.spinner(text="Loading embedding model..."):
|
89 |
model = load_model(model_name)
|
90 |
dimensionality_reduction_function = (
|
|
|
94 |
)
|
95 |
|
96 |
with st.spinner(text="Loading KenLM model..."):
|
97 |
+
kenlm_model = KenlmModel.from_pretrained(
|
98 |
+
perplexity_model,
|
99 |
+
language,
|
100 |
+
lower_case,
|
101 |
+
remove_accents,
|
102 |
+
normalize_numbers,
|
103 |
+
punctuation,
|
104 |
+
)
|
105 |
|
106 |
if uploaded_file or hub_dataset:
|
107 |
with st.spinner("Loading dataset..."):
|
|
|
127 |
logger.info(
|
128 |
f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
|
129 |
)
|
130 |
+
plot, plot_registry = generate_plot(
|
131 |
df,
|
132 |
text_column,
|
133 |
"perplexity",
|
|
|
136 |
model,
|
137 |
seed=SEED,
|
138 |
context_logger=st.spinner,
|
139 |
+
hub_dataset=hub_dataset,
|
140 |
)
|
141 |
+
logger.info("Displaying plots")
|
142 |
st.bokeh_chart(plot)
|
143 |
+
if hub_dataset == REGISTRY_DATASET:
|
144 |
+
st.bokeh_chart(plot_registry)
|
145 |
+
fig = draw_histogram(df["perplexity"].values)
|
146 |
+
st.pyplot(fig)
|
147 |
logger.info("Done")
|
cli.py
CHANGED
@@ -6,17 +6,28 @@ import pandas as pd
|
|
6 |
import typer
|
7 |
from bokeh.plotting import output_file as bokeh_output_file
|
8 |
from bokeh.plotting import save
|
9 |
-
from embedding_lenses.
|
10 |
-
|
11 |
-
|
|
|
12 |
from embedding_lenses.embedding import load_model
|
13 |
|
14 |
-
from perplexity_lenses
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
from perplexity_lenses.perplexity import KenlmModel
|
|
|
20 |
|
21 |
logging.basicConfig(level=logging.INFO)
|
22 |
logger = logging.getLogger(__name__)
|
@@ -46,6 +57,10 @@ def main(
|
|
46 |
help=f"Whether to embed at the sentence or document level. Options: {DOCUMENT_TYPES}.",
|
47 |
),
|
48 |
sample: int = typer.Option(1000, help="Maximum number of examples to use."),
|
|
|
|
|
|
|
|
|
49 |
dimensionality_reduction: str = typer.Option(
|
50 |
DIMENSIONALITY_REDUCTION_ALGORITHMS[0],
|
51 |
help=f"Whether to use UMAP or t-SNE for dimensionality reduction. Options: {DIMENSIONALITY_REDUCTION_ALGORITHMS}.",
|
@@ -55,7 +70,7 @@ def main(
|
|
55 |
help=f"The sentence embedding model to use. Options: {EMBEDDING_MODELS}",
|
56 |
),
|
57 |
output_file: str = typer.Option(
|
58 |
-
"perplexity
|
59 |
),
|
60 |
):
|
61 |
"""
|
@@ -69,7 +84,14 @@ def main(
|
|
69 |
else partial(get_tsne_embeddings, random_state=SEED)
|
70 |
)
|
71 |
logger.info("Loading KenLM model...")
|
72 |
-
kenlm_model = KenlmModel.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
logger.info("Loading dataset...")
|
74 |
if dataset.endswith(".csv") or dataset.endswith(".tsv"):
|
75 |
df = pd.read_csv(dataset, sep="\t" if dataset.endswith(".tsv") else ",")
|
@@ -92,7 +114,7 @@ def main(
|
|
92 |
logger.info(
|
93 |
f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
|
94 |
)
|
95 |
-
plot = generate_plot(
|
96 |
df,
|
97 |
text_column,
|
98 |
"perplexity",
|
@@ -100,10 +122,16 @@ def main(
|
|
100 |
dimensionality_reduction_function,
|
101 |
model,
|
102 |
seed=SEED,
|
|
|
103 |
)
|
104 |
-
logger.info("Saving
|
105 |
-
bokeh_output_file(output_file)
|
106 |
save(plot)
|
|
|
|
|
|
|
|
|
|
|
107 |
logger.info("Done")
|
108 |
|
109 |
|
|
|
6 |
import typer
|
7 |
from bokeh.plotting import output_file as bokeh_output_file
|
8 |
from bokeh.plotting import save
|
9 |
+
from embedding_lenses.dimensionality_reduction import (
|
10 |
+
get_tsne_embeddings,
|
11 |
+
get_umap_embeddings,
|
12 |
+
)
|
13 |
from embedding_lenses.embedding import load_model
|
14 |
|
15 |
+
from perplexity_lenses import REGISTRY_DATASET
|
16 |
+
from perplexity_lenses.data import (
|
17 |
+
documents_df_to_sentences_df,
|
18 |
+
hub_dataset_to_dataframe,
|
19 |
+
)
|
20 |
+
from perplexity_lenses.engine import (
|
21 |
+
DIMENSIONALITY_REDUCTION_ALGORITHMS,
|
22 |
+
DOCUMENT_TYPES,
|
23 |
+
EMBEDDING_MODELS,
|
24 |
+
LANGUAGES,
|
25 |
+
PERPLEXITY_MODELS,
|
26 |
+
SEED,
|
27 |
+
generate_plot,
|
28 |
+
)
|
29 |
from perplexity_lenses.perplexity import KenlmModel
|
30 |
+
from perplexity_lenses.visualization import draw_histogram
|
31 |
|
32 |
logging.basicConfig(level=logging.INFO)
|
33 |
logger = logging.getLogger(__name__)
|
|
|
57 |
help=f"Whether to embed at the sentence or document level. Options: {DOCUMENT_TYPES}.",
|
58 |
),
|
59 |
sample: int = typer.Option(1000, help="Maximum number of examples to use."),
|
60 |
+
perplexity_model: str = typer.Option(
|
61 |
+
"wikipedia",
|
62 |
+
help=f"Dataset on which the perplexity model was trained on. Options: {PERPLEXITY_MODELS}",
|
63 |
+
),
|
64 |
dimensionality_reduction: str = typer.Option(
|
65 |
DIMENSIONALITY_REDUCTION_ALGORITHMS[0],
|
66 |
help=f"Whether to use UMAP or t-SNE for dimensionality reduction. Options: {DIMENSIONALITY_REDUCTION_ALGORITHMS}.",
|
|
|
70 |
help=f"The sentence embedding model to use. Options: {EMBEDDING_MODELS}",
|
71 |
),
|
72 |
output_file: str = typer.Option(
|
73 |
+
"perplexity", help="The name of the output visualization files."
|
74 |
),
|
75 |
):
|
76 |
"""
|
|
|
84 |
else partial(get_tsne_embeddings, random_state=SEED)
|
85 |
)
|
86 |
logger.info("Loading KenLM model...")
|
87 |
+
kenlm_model = KenlmModel.from_pretrained(
|
88 |
+
perplexity_model.lower(),
|
89 |
+
language,
|
90 |
+
lower_case=True,
|
91 |
+
remove_accents=True,
|
92 |
+
normalize_numbers=True,
|
93 |
+
punctuation=1,
|
94 |
+
)
|
95 |
logger.info("Loading dataset...")
|
96 |
if dataset.endswith(".csv") or dataset.endswith(".tsv"):
|
97 |
df = pd.read_csv(dataset, sep="\t" if dataset.endswith(".tsv") else ",")
|
|
|
114 |
logger.info(
|
115 |
f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
|
116 |
)
|
117 |
+
plot, plot_registry = generate_plot(
|
118 |
df,
|
119 |
text_column,
|
120 |
"perplexity",
|
|
|
122 |
dimensionality_reduction_function,
|
123 |
model,
|
124 |
seed=SEED,
|
125 |
+
hub_dataset=dataset,
|
126 |
)
|
127 |
+
logger.info("Saving plots")
|
128 |
+
bokeh_output_file(f"{output_file}.html")
|
129 |
save(plot)
|
130 |
+
if dataset == REGISTRY_DATASET:
|
131 |
+
bokeh_output_file(f"{output_file}_registry.html")
|
132 |
+
save(plot_registry)
|
133 |
+
fig = draw_histogram(df["perplexity"].values)
|
134 |
+
fig.savefig(f"{output_file}_histogram.png")
|
135 |
logger.info("Done")
|
136 |
|
137 |
|
perplexity_lenses/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1 |
__version__ = "0.1.0"
|
|
|
|
1 |
__version__ = "0.1.0"
|
2 |
+
REGISTRY_DATASET = "mhtoin/register_oscar"
|
perplexity_lenses/data.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
from datasets import load_dataset
|
6 |
from tqdm import tqdm
|
7 |
|
|
|
8 |
from perplexity_lenses.perplexity import KenlmModel
|
9 |
|
10 |
|
@@ -21,13 +22,22 @@ def hub_dataset_to_dataframe(
|
|
21 |
load_dataset_fn = partial(load_dataset, path=path)
|
22 |
if name:
|
23 |
load_dataset_fn = partial(load_dataset_fn, name=name)
|
|
|
|
|
|
|
24 |
if split:
|
25 |
load_dataset_fn = partial(load_dataset_fn, split=split)
|
26 |
dataset = load_dataset_fn(streaming=True).shuffle(buffer_size=10000, seed=seed)
|
27 |
if doc_type.lower() == "sentence":
|
28 |
dataset = dataset.map(
|
29 |
lambda x: [
|
30 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
for sentence in x[text_column].split("\n")
|
32 |
]
|
33 |
)
|
@@ -36,6 +46,9 @@ def hub_dataset_to_dataframe(
|
|
36 |
lambda x: {
|
37 |
text_column: x[text_column],
|
38 |
"perplexity": model.get_perplexity(x[text_column]),
|
|
|
|
|
|
|
39 |
}
|
40 |
)
|
41 |
instances = []
|
|
|
5 |
from datasets import load_dataset
|
6 |
from tqdm import tqdm
|
7 |
|
8 |
+
from perplexity_lenses import REGISTRY_DATASET
|
9 |
from perplexity_lenses.perplexity import KenlmModel
|
10 |
|
11 |
|
|
|
22 |
load_dataset_fn = partial(load_dataset, path=path)
|
23 |
if name:
|
24 |
load_dataset_fn = partial(load_dataset_fn, name=name)
|
25 |
+
# Special case for the registry dataset
|
26 |
+
if path == REGISTRY_DATASET:
|
27 |
+
load_dataset_fn = partial(load_dataset_fn, data_files=f"{name}/*")
|
28 |
if split:
|
29 |
load_dataset_fn = partial(load_dataset_fn, split=split)
|
30 |
dataset = load_dataset_fn(streaming=True).shuffle(buffer_size=10000, seed=seed)
|
31 |
if doc_type.lower() == "sentence":
|
32 |
dataset = dataset.map(
|
33 |
lambda x: [
|
34 |
+
{
|
35 |
+
text_column: sentence,
|
36 |
+
"perplexity": model.get_perplexity(sentence),
|
37 |
+
"label": x.get("labels", [])[0]
|
38 |
+
if len(x.get("labels", [])) > 0
|
39 |
+
else "NONE", # Special case for registry dataset
|
40 |
+
}
|
41 |
for sentence in x[text_column].split("\n")
|
42 |
]
|
43 |
)
|
|
|
46 |
lambda x: {
|
47 |
text_column: x[text_column],
|
48 |
"perplexity": model.get_perplexity(x[text_column]),
|
49 |
+
"label": x.get("labels", [])[0]
|
50 |
+
if len(x.get("labels", [])) > 0
|
51 |
+
else "NONE", # Special case for registry dataset
|
52 |
}
|
53 |
)
|
54 |
instances = []
|
perplexity_lenses/engine.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1 |
import logging
|
2 |
import time
|
3 |
-
from typing import Callable, Optional, Union
|
4 |
|
5 |
import pandas as pd
|
6 |
import streamlit as st
|
|
|
7 |
from bokeh.plotting import Figure
|
8 |
from embedding_lenses.embedding import embed_text
|
9 |
from embedding_lenses.utils import encode_labels
|
|
|
10 |
from sentence_transformers import SentenceTransformer
|
11 |
|
12 |
-
from perplexity_lenses
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
EMBEDDING_MODELS = [
|
@@ -71,6 +73,7 @@ LANGUAGES = [
|
|
71 |
"uk",
|
72 |
"zh",
|
73 |
]
|
|
|
74 |
|
75 |
|
76 |
class ContextLogger:
|
@@ -94,7 +97,8 @@ def generate_plot(
|
|
94 |
model: SentenceTransformer,
|
95 |
seed: int = 0,
|
96 |
context_logger: Union[st.spinner, ContextLogger] = ContextLogger,
|
97 |
-
|
|
|
98 |
if text_column not in df.columns:
|
99 |
raise ValueError(
|
100 |
f"The specified column name doesn't exist. Columns available: {df.columns.values}"
|
@@ -111,13 +115,32 @@ def generate_plot(
|
|
111 |
with context_logger("Reducing dimensionality..."):
|
112 |
embeddings_2d = dimensionality_reduction_function(embeddings)
|
113 |
logger.info("Generating figure")
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
plot = draw_interactive_scatter_plot(
|
115 |
-
|
116 |
embeddings_2d[:, 0],
|
117 |
embeddings_2d[:, 1],
|
118 |
-
|
119 |
-
df[label_column].values,
|
120 |
-
text_column,
|
121 |
-
label_column,
|
122 |
)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
import time
|
3 |
+
from typing import Callable, Optional, Tuple, Union
|
4 |
|
5 |
import pandas as pd
|
6 |
import streamlit as st
|
7 |
+
from bokeh.palettes import Turbo256
|
8 |
from bokeh.plotting import Figure
|
9 |
from embedding_lenses.embedding import embed_text
|
10 |
from embedding_lenses.utils import encode_labels
|
11 |
+
from embedding_lenses.visualization import draw_interactive_scatter_plot
|
12 |
from sentence_transformers import SentenceTransformer
|
13 |
|
14 |
+
from perplexity_lenses import REGISTRY_DATASET
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
EMBEDDING_MODELS = [
|
|
|
73 |
"uk",
|
74 |
"zh",
|
75 |
]
|
76 |
+
PERPLEXITY_MODELS = ["Wikipedia", "OSCAR"]
|
77 |
|
78 |
|
79 |
class ContextLogger:
|
|
|
97 |
model: SentenceTransformer,
|
98 |
seed: int = 0,
|
99 |
context_logger: Union[st.spinner, ContextLogger] = ContextLogger,
|
100 |
+
hub_dataset: str = "",
|
101 |
+
) -> Tuple[Figure, Optional[Figure]]:
|
102 |
if text_column not in df.columns:
|
103 |
raise ValueError(
|
104 |
f"The specified column name doesn't exist. Columns available: {df.columns.values}"
|
|
|
115 |
with context_logger("Reducing dimensionality..."):
|
116 |
embeddings_2d = dimensionality_reduction_function(embeddings)
|
117 |
logger.info("Generating figure")
|
118 |
+
hover_data = {
|
119 |
+
text_column: df[text_column].values,
|
120 |
+
label_column: encoded_labels.values,
|
121 |
+
}
|
122 |
+
# Round perplexity values
|
123 |
+
values = df[label_column].values.round().astype(int)
|
124 |
plot = draw_interactive_scatter_plot(
|
125 |
+
hover_data,
|
126 |
embeddings_2d[:, 0],
|
127 |
embeddings_2d[:, 1],
|
128 |
+
values,
|
|
|
|
|
|
|
129 |
)
|
130 |
+
# Special case for the registry dataset
|
131 |
+
plot_registry = None
|
132 |
+
if hub_dataset == REGISTRY_DATASET:
|
133 |
+
encoded_labels = encode_labels(df["label"])
|
134 |
+
hover_data = {
|
135 |
+
text_column: df[text_column].values,
|
136 |
+
"label": df["label"].values,
|
137 |
+
label_column: df[label_column].values,
|
138 |
+
}
|
139 |
+
plot_registry = draw_interactive_scatter_plot(
|
140 |
+
hover_data,
|
141 |
+
embeddings_2d[:, 0],
|
142 |
+
embeddings_2d[:, 1],
|
143 |
+
encoded_labels.values,
|
144 |
+
palette=Turbo256,
|
145 |
+
)
|
146 |
+
return plot, plot_registry
|
perplexity_lenses/perplexity.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
import os
|
2 |
import re
|
3 |
import unicodedata
|
4 |
-
import urllib.request
|
5 |
from typing import Dict
|
6 |
|
7 |
import kenlm
|
8 |
import sentencepiece
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
class SentencePiece:
|
@@ -64,30 +66,65 @@ class KenlmModel:
|
|
64 |
non_printing_chars_re = re.compile(
|
65 |
f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
|
66 |
)
|
|
|
|
|
67 |
|
68 |
-
def __init__(
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
try:
|
71 |
-
self.model = kenlm.Model(
|
72 |
-
self.tokenizer = SentencePiece(
|
73 |
except OSError:
|
74 |
-
os.remove(
|
75 |
-
if os.path.exists(
|
76 |
-
os.remove(
|
77 |
raise OSError(
|
78 |
"File was corrupt and should have been removed. Please, retry."
|
79 |
)
|
|
|
|
|
|
|
|
|
80 |
|
81 |
@classmethod
|
82 |
-
def from_pretrained(
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def pp(self, log_score, length):
|
86 |
return 10.0 ** (-log_score / length)
|
87 |
|
88 |
def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
|
89 |
if normalize_cc_net:
|
90 |
-
doc = self.normalize(
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
# Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
|
92 |
doc = self.tokenizer.do(doc)
|
93 |
doc_log_score, doc_length = 0, 0
|
@@ -140,16 +177,12 @@ class KenlmModel:
|
|
140 |
def remove_non_printing_char(self, text: str) -> str:
|
141 |
return self.non_printing_chars_re.sub("", text)
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
urllib.request.urlretrieve(bin_url, bin_name)
|
153 |
-
|
154 |
-
if not os.path.isfile(model_name):
|
155 |
-
urllib.request.urlretrieve(model_url, model_name)
|
|
|
1 |
import os
|
2 |
import re
|
3 |
import unicodedata
|
|
|
4 |
from typing import Dict
|
5 |
|
6 |
import kenlm
|
7 |
import sentencepiece
|
8 |
+
from huggingface_hub import cached_download, hf_hub_url
|
9 |
+
|
10 |
+
KENLM_MODEL_REPO = "edugp/kenlm"
|
11 |
|
12 |
|
13 |
class SentencePiece:
|
|
|
66 |
non_printing_chars_re = re.compile(
|
67 |
f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
|
68 |
)
|
69 |
+
kenlm_model_dir = None
|
70 |
+
sentence_piece_model_dir = None
|
71 |
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
model_dataset: str,
|
75 |
+
language: str,
|
76 |
+
lower_case: bool = False,
|
77 |
+
remove_accents: bool = False,
|
78 |
+
normalize_numbers: bool = True,
|
79 |
+
punctuation: int = 1,
|
80 |
+
):
|
81 |
+
self.download_kenlm_model(model_dataset, language)
|
82 |
try:
|
83 |
+
self.model = kenlm.Model(self.kenlm_model_dir)
|
84 |
+
self.tokenizer = SentencePiece(self.sentence_piece_model_dir)
|
85 |
except OSError:
|
86 |
+
os.remove(self.kenlm_model_dir)
|
87 |
+
if os.path.exists(self.sentence_piece_model_dir):
|
88 |
+
os.remove(self.sentence_piece_model_dir)
|
89 |
raise OSError(
|
90 |
"File was corrupt and should have been removed. Please, retry."
|
91 |
)
|
92 |
+
self.accent = remove_accents
|
93 |
+
self.case = lower_case
|
94 |
+
self.numbers = normalize_numbers
|
95 |
+
self.punct = punctuation
|
96 |
|
97 |
@classmethod
|
98 |
+
def from_pretrained(
|
99 |
+
cls,
|
100 |
+
model_dataset: str,
|
101 |
+
language: str,
|
102 |
+
lower_case: bool,
|
103 |
+
remove_accents: bool,
|
104 |
+
normalize_numbers: bool,
|
105 |
+
punctuation: int,
|
106 |
+
):
|
107 |
+
return cls(
|
108 |
+
model_dataset,
|
109 |
+
language,
|
110 |
+
lower_case,
|
111 |
+
remove_accents,
|
112 |
+
normalize_numbers,
|
113 |
+
punctuation,
|
114 |
+
)
|
115 |
|
116 |
def pp(self, log_score, length):
|
117 |
return 10.0 ** (-log_score / length)
|
118 |
|
119 |
def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
|
120 |
if normalize_cc_net:
|
121 |
+
doc = self.normalize(
|
122 |
+
doc,
|
123 |
+
accent=self.accent,
|
124 |
+
case=self.case,
|
125 |
+
numbers=self.numbers,
|
126 |
+
punct=self.punct,
|
127 |
+
)
|
128 |
# Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
|
129 |
doc = self.tokenizer.do(doc)
|
130 |
doc_log_score, doc_length = 0, 0
|
|
|
177 |
def remove_non_printing_char(self, text: str) -> str:
|
178 |
return self.non_printing_chars_re.sub("", text)
|
179 |
|
180 |
+
def download_kenlm_model(self, model_dataset: str, language: str):
|
181 |
+
kenlm_model_url = hf_hub_url(
|
182 |
+
KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.arpa.bin"
|
183 |
+
)
|
184 |
+
self.kenlm_model_dir = cached_download(kenlm_model_url)
|
185 |
+
sentence_piece_model_url = hf_hub_url(
|
186 |
+
KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.sp.model"
|
187 |
+
)
|
188 |
+
self.sentence_piece_model_dir = cached_download(sentence_piece_model_url)
|
|
|
|
|
|
|
|
perplexity_lenses/visualization.py
CHANGED
@@ -1,57 +1,18 @@
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
-
from bokeh.models import ColumnDataSource, HoverTool
|
3 |
-
from bokeh.palettes import Cividis256 as Pallete
|
4 |
-
from bokeh.plotting import Figure, figure
|
5 |
-
from bokeh.transform import factor_cmap
|
6 |
|
7 |
|
8 |
-
def
|
9 |
-
texts: np.ndarray,
|
10 |
-
xs: np.ndarray,
|
11 |
-
ys: np.ndarray,
|
12 |
values: np.ndarray,
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
) -> Figure:
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
else:
|
25 |
-
values_color = (
|
26 |
-
((values - min_value) / (max_value - min_value) * 255).round().astype(int)
|
27 |
-
)
|
28 |
-
values_color_sorted = sorted(values_color)
|
29 |
-
|
30 |
-
values_list = values.astype(str).tolist()
|
31 |
-
values_sorted = sorted(values_list)
|
32 |
-
labels_list = labels.astype(str).tolist()
|
33 |
-
|
34 |
-
source = ColumnDataSource(
|
35 |
-
data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)
|
36 |
-
)
|
37 |
-
hover = HoverTool(
|
38 |
-
tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")]
|
39 |
-
)
|
40 |
-
p = figure(plot_width=800, plot_height=800, tools=[hover])
|
41 |
-
p.circle(
|
42 |
-
"x",
|
43 |
-
"y",
|
44 |
-
size=10,
|
45 |
-
source=source,
|
46 |
-
fill_color=factor_cmap(
|
47 |
-
"label",
|
48 |
-
palette=[Pallete[id_] for id_ in values_color_sorted],
|
49 |
-
factors=values_sorted,
|
50 |
-
),
|
51 |
-
)
|
52 |
-
|
53 |
-
p.axis.visible = False
|
54 |
-
p.xgrid.grid_line_color = None
|
55 |
-
p.ygrid.grid_line_color = None
|
56 |
-
p.toolbar.logo = None
|
57 |
-
return p
|
|
|
1 |
+
import matplotlib.figure
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
import numpy as np
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
+
def draw_histogram(
|
|
|
|
|
|
|
7 |
values: np.ndarray,
|
8 |
+
cutoff_x_axis: float = 2000.0,
|
9 |
+
title: str = "Perplexity histogram",
|
10 |
+
xlabel: str = "Perplexity",
|
11 |
+
) -> matplotlib.figure.Figure:
|
12 |
+
hist_values = values[values < cutoff_x_axis]
|
13 |
+
fig, ax = plt.subplots(figsize=(12, 9))
|
14 |
+
ax.hist(hist_values, bins=50)
|
15 |
+
ax.set_title(title)
|
16 |
+
ax.set_xlabel(xlabel)
|
17 |
+
ax.set_ylabel("Counts")
|
18 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
bokeh==2.2.2
|
2 |
-
https://files.pythonhosted.org/packages/
|
3 |
https://github.com/kpu/kenlm/archive/master.zip
|
4 |
huggingface-hub==0.0.19
|
5 |
numpy==1.20.0
|
|
|
1 |
bokeh==2.2.2
|
2 |
+
https://files.pythonhosted.org/packages/52/b6/798b1ac755c19cc5a91e5f5068544ed9deda74cb9339e4bb5e2b5f3ff3ec/embedding_lenses-0.8.0-py3-none-any.whl
|
3 |
https://github.com/kpu/kenlm/archive/master.zip
|
4 |
huggingface-hub==0.0.19
|
5 |
numpy==1.20.0
|