Cache label data along with tokenized text data
- app/cli.py  +26 -20
- app/utils.py  +2 -3
app/cli.py  (CHANGED)
@@ -146,32 +146,35 @@ def evaluate(
     from app.model import evaluate_model
     from app.utils import deserialize, serialize

-    …
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False

-    if …
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )

-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(…
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)

         click.echo("Caching tokenized data... ")
-        serialize(token_data, …
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)

-        …
-        …
+        del text_data
+        gc.collect()

     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()

@@ -281,32 +284,35 @@ def train(
     if model_path.exists() and not overwrite:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)

-    …
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False

-    if …
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )

-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(…
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)

         click.echo("Caching tokenized data... ")
-        serialize(token_data, …
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)

-        …
-        …
+        del text_data
+        gc.collect()

     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()
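For reference, a condensed, self-contained sketch of the flow both commands now share: check for per-dataset token and label caches, reuse them when present, otherwise load, tokenize, and write both caches. This is an illustration only, not the repo's code: load_data and tokenize here are hypothetical stand-ins for the repo's helpers, cache_dir stands in for TOKENIZER_CACHE_PATH, plain joblib calls replace the chunked serialize/deserialize from app/utils.py, and the click.confirm prompt is reduced to a boolean flag.

import gc
from pathlib import Path

import joblib
import pandas as pd


def load_data(dataset: str) -> tuple[pd.Series, pd.Series]:
    # Hypothetical stand-in for the repo's load_data(): returns (texts, labels).
    texts = pd.Series(["a great film", "a dull film", "great plot, great cast"])
    labels = pd.Series([1, 0, 1])
    return texts, labels


def tokenize(texts: pd.Series) -> pd.Series:
    # Hypothetical stand-in for the repo's tokenize(): naive whitespace split.
    return texts.str.lower().str.split()


def load_tokens_and_labels(dataset: str, cache_dir: Path, use_cache: bool = True) -> tuple[pd.Series, pd.Series]:
    token_cache = cache_dir / f"{dataset}_tokenized.pkl"
    label_cache = cache_dir / f"{dataset}_labels.pkl"

    if use_cache and token_cache.exists() and label_cache.exists():
        # Cache hit: reuse the tokenized text and the labels saved with it,
        # so the raw dataset never has to be loaded or re-tokenized.
        return pd.Series(joblib.load(token_cache)), joblib.load(label_cache)

    # Cache miss: load the raw dataset, tokenize, and cache both artifacts.
    text_data, label_data = load_data(dataset)
    token_data = tokenize(text_data)

    cache_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(list(token_data), token_cache, compress=3)
    joblib.dump(label_data, label_cache, compress=3)

    # The raw text is no longer needed once tokens are cached; free it early.
    del text_data
    gc.collect()
    return token_data, label_data


tokens, labels = load_tokens_and_labels("demo", Path("cache"))
print(len(tokens), len(labels))

Caching the labels next to the tokens is what lets the else-branch own load_data entirely: on a cache hit nothing touches the raw dataset at all.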
app/utils.py  (CHANGED)
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 __all__ = ["serialize", "deserialize"]


-def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
+def serialize(data: Sequence[str | int], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
     """Serialize data to a file

     Args:
@@ -20,7 +20,6 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         max_size: The maximum size a chunk can be (in elements)
         show_progress: Whether to show a progress bar
     """
-    # first file is path, next chunks have ".1", ".2", etc. appended
     for i, chunk in enumerate(
         tqdm(
             [data[i : i + max_size] for i in range(0, len(data), max_size)],
@@ -33,7 +32,7 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         joblib.dump(chunk, f, compress=3)


-def deserialize(path: Path) -> Sequence[str]:
+def deserialize(path: Path) -> Sequence[str | int]:
     """Deserialize data from a file

     Args: