Spaces:
Sleeping
Sleeping
Update options, force GC, tweak parameters and add flags
Browse files- app/cli.py +26 -7
- app/model.py +7 -5
app/cli.py
CHANGED
@@ -111,8 +111,8 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
111 |
)
|
112 |
@click.option(
|
113 |
"--processes",
|
114 |
-
default=
|
115 |
-
help="Number of parallel jobs
|
116 |
show_default=True,
|
117 |
)
|
118 |
@click.option(
|
@@ -129,6 +129,8 @@ def evaluate(
|
|
129 |
verbose: bool,
|
130 |
) -> None:
|
131 |
"""Evaluate the model on the the specified dataset"""
|
|
|
|
|
132 |
import joblib
|
133 |
|
134 |
from app.constants import CACHE_DIR
|
@@ -155,13 +157,21 @@ def evaluate(
|
|
155 |
click.echo(DONE_STR)
|
156 |
|
157 |
del text_data
|
|
|
158 |
|
159 |
click.echo("Loading model... ", nl=False)
|
160 |
model = joblib.load(model_path)
|
161 |
click.echo(DONE_STR)
|
162 |
|
163 |
click.echo("Evaluating model... ", nl=False)
|
164 |
-
acc_mean, acc_std = evaluate_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
166 |
|
167 |
|
@@ -206,10 +216,15 @@ def evaluate(
|
|
206 |
type=click.IntRange(-1, None),
|
207 |
)
|
208 |
@click.option(
|
209 |
-
"--
|
210 |
is_flag=True,
|
211 |
help="Overwrite the model file if it already exists",
|
212 |
)
|
|
|
|
|
|
|
|
|
|
|
213 |
@click.option(
|
214 |
"--verbose",
|
215 |
is_flag=True,
|
@@ -222,10 +237,13 @@ def train(
|
|
222 |
batch_size: int,
|
223 |
processes: int,
|
224 |
seed: int,
|
225 |
-
|
|
|
226 |
verbose: bool,
|
227 |
) -> None:
|
228 |
"""Train the model on the provided dataset"""
|
|
|
|
|
229 |
import joblib
|
230 |
|
231 |
from app.constants import CACHE_DIR, MODELS_DIR
|
@@ -233,12 +251,12 @@ def train(
|
|
233 |
from app.model import train_model
|
234 |
|
235 |
model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
|
236 |
-
if model_path.exists() and not
|
237 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
238 |
|
239 |
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
240 |
use_cached_data = False
|
241 |
-
if cached_data_path.exists():
|
242 |
use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
|
243 |
|
244 |
if use_cached_data:
|
@@ -256,6 +274,7 @@ def train(
|
|
256 |
click.echo(DONE_STR)
|
257 |
|
258 |
del text_data
|
|
|
259 |
|
260 |
click.echo("Training model... ")
|
261 |
model, accuracy = train_model(
|
|
|
111 |
)
|
112 |
@click.option(
|
113 |
"--processes",
|
114 |
+
default=4,
|
115 |
+
help="Number of parallel jobs to run",
|
116 |
show_default=True,
|
117 |
)
|
118 |
@click.option(
|
|
|
129 |
verbose: bool,
|
130 |
) -> None:
|
131 |
"""Evaluate the model on the the specified dataset"""
|
132 |
+
import gc
|
133 |
+
|
134 |
import joblib
|
135 |
|
136 |
from app.constants import CACHE_DIR
|
|
|
157 |
click.echo(DONE_STR)
|
158 |
|
159 |
del text_data
|
160 |
+
gc.collect()
|
161 |
|
162 |
click.echo("Loading model... ", nl=False)
|
163 |
model = joblib.load(model_path)
|
164 |
click.echo(DONE_STR)
|
165 |
|
166 |
click.echo("Evaluating model... ", nl=False)
|
167 |
+
acc_mean, acc_std = evaluate_model(
|
168 |
+
model,
|
169 |
+
token_data,
|
170 |
+
label_data,
|
171 |
+
folds=cv,
|
172 |
+
n_jobs=processes,
|
173 |
+
verbose=verbose,
|
174 |
+
)
|
175 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
176 |
|
177 |
|
|
|
216 |
type=click.IntRange(-1, None),
|
217 |
)
|
218 |
@click.option(
|
219 |
+
"--overwrite",
|
220 |
is_flag=True,
|
221 |
help="Overwrite the model file if it already exists",
|
222 |
)
|
223 |
+
@click.option(
|
224 |
+
"--skip-cache",
|
225 |
+
is_flag=True,
|
226 |
+
help="Ignore cached tokenized data",
|
227 |
+
)
|
228 |
@click.option(
|
229 |
"--verbose",
|
230 |
is_flag=True,
|
|
|
237 |
batch_size: int,
|
238 |
processes: int,
|
239 |
seed: int,
|
240 |
+
overwrite: bool,
|
241 |
+
skip_cache: bool,
|
242 |
verbose: bool,
|
243 |
) -> None:
|
244 |
"""Train the model on the provided dataset"""
|
245 |
+
import gc
|
246 |
+
|
247 |
import joblib
|
248 |
|
249 |
from app.constants import CACHE_DIR, MODELS_DIR
|
|
|
251 |
from app.model import train_model
|
252 |
|
253 |
model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
|
254 |
+
if model_path.exists() and not overwrite:
|
255 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
256 |
|
257 |
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
258 |
use_cached_data = False
|
259 |
+
if cached_data_path.exists() and not skip_cache:
|
260 |
use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
|
261 |
|
262 |
if use_cached_data:
|
|
|
274 |
click.echo(DONE_STR)
|
275 |
|
276 |
del text_data
|
277 |
+
gc.collect()
|
278 |
|
279 |
click.echo("Training model... ")
|
280 |
model, accuracy = train_model(
|
app/model.py
CHANGED
@@ -99,14 +99,14 @@ def train_model(
|
|
99 |
cv=folds,
|
100 |
random_state=seed,
|
101 |
n_jobs=n_jobs,
|
102 |
-
verbose=verbose,
|
103 |
scoring="accuracy",
|
104 |
n_iter=10,
|
105 |
)
|
106 |
|
107 |
-
|
108 |
search.fit(text_train, label_train)
|
109 |
-
|
110 |
|
111 |
best_model = search.best_estimator_
|
112 |
return best_model, best_model.score(text_test, label_test)
|
@@ -117,6 +117,7 @@ def evaluate_model(
|
|
117 |
token_data: list[str],
|
118 |
label_data: list[int],
|
119 |
folds: int = 5,
|
|
|
120 |
verbose: bool = False,
|
121 |
) -> tuple[float, float]:
|
122 |
"""Evaluate the model using cross-validation.
|
@@ -126,6 +127,7 @@ def evaluate_model(
|
|
126 |
token_data: Tokenized text data
|
127 |
label_data: Label data
|
128 |
folds: Number of cross-validation folds
|
|
|
129 |
verbose: Whether to output additional information
|
130 |
|
131 |
Returns:
|
@@ -138,8 +140,8 @@ def evaluate_model(
|
|
138 |
label_data,
|
139 |
cv=folds,
|
140 |
scoring="accuracy",
|
141 |
-
n_jobs
|
142 |
-
verbose=verbose,
|
143 |
)
|
144 |
del os.environ["PYTHONWARNINGS"]
|
145 |
return scores.mean(), scores.std()
|
|
|
99 |
cv=folds,
|
100 |
random_state=seed,
|
101 |
n_jobs=n_jobs,
|
102 |
+
verbose=2 if verbose else 0,
|
103 |
scoring="accuracy",
|
104 |
n_iter=10,
|
105 |
)
|
106 |
|
107 |
+
os.environ["PYTHONWARNINGS"] = "ignore"
|
108 |
search.fit(text_train, label_train)
|
109 |
+
del os.environ["PYTHONWARNINGS"]
|
110 |
|
111 |
best_model = search.best_estimator_
|
112 |
return best_model, best_model.score(text_test, label_test)
|
|
|
117 |
token_data: list[str],
|
118 |
label_data: list[int],
|
119 |
folds: int = 5,
|
120 |
+
n_jobs: int = 4,
|
121 |
verbose: bool = False,
|
122 |
) -> tuple[float, float]:
|
123 |
"""Evaluate the model using cross-validation.
|
|
|
127 |
token_data: Tokenized text data
|
128 |
label_data: Label data
|
129 |
folds: Number of cross-validation folds
|
130 |
+
n_jobs: Number of parallel jobs
|
131 |
verbose: Whether to output additional information
|
132 |
|
133 |
Returns:
|
|
|
140 |
label_data,
|
141 |
cv=folds,
|
142 |
scoring="accuracy",
|
143 |
+
n_jobs=n_jobs,
|
144 |
+
verbose=2 if verbose else 0,
|
145 |
)
|
146 |
del os.environ["PYTHONWARNINGS"]
|
147 |
return scores.mean(), scores.std()
|