Tymec commited on
Commit
18cc46a
·
1 Parent(s): 8471e78

Update options, force GC, tweak parameters and add flags

Browse files
Files changed (2) hide show
  1. app/cli.py +26 -7
  2. app/model.py +7 -5
app/cli.py CHANGED
@@ -111,8 +111,8 @@ def predict(model_path: Path, text: list[str]) -> None:
111
  )
112
  @click.option(
113
  "--processes",
114
- default=8,
115
- help="Number of parallel jobs during tokenization",
116
  show_default=True,
117
  )
118
  @click.option(
@@ -129,6 +129,8 @@ def evaluate(
129
  verbose: bool,
130
  ) -> None:
131
  """Evaluate the model on the the specified dataset"""
 
 
132
  import joblib
133
 
134
  from app.constants import CACHE_DIR
@@ -155,13 +157,21 @@ def evaluate(
155
  click.echo(DONE_STR)
156
 
157
  del text_data
 
158
 
159
  click.echo("Loading model... ", nl=False)
160
  model = joblib.load(model_path)
161
  click.echo(DONE_STR)
162
 
163
  click.echo("Evaluating model... ", nl=False)
164
- acc_mean, acc_std = evaluate_model(model, token_data, label_data, folds=cv, verbose=verbose)
 
 
 
 
 
 
 
165
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
166
 
167
 
@@ -206,10 +216,15 @@ def evaluate(
206
  type=click.IntRange(-1, None),
207
  )
208
  @click.option(
209
- "--force",
210
  is_flag=True,
211
  help="Overwrite the model file if it already exists",
212
  )
 
 
 
 
 
213
  @click.option(
214
  "--verbose",
215
  is_flag=True,
@@ -222,10 +237,13 @@ def train(
222
  batch_size: int,
223
  processes: int,
224
  seed: int,
225
- force: bool,
 
226
  verbose: bool,
227
  ) -> None:
228
  """Train the model on the provided dataset"""
 
 
229
  import joblib
230
 
231
  from app.constants import CACHE_DIR, MODELS_DIR
@@ -233,12 +251,12 @@ def train(
233
  from app.model import train_model
234
 
235
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
236
- if model_path.exists() and not force:
237
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
238
 
239
  cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
240
  use_cached_data = False
241
- if cached_data_path.exists():
242
  use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
243
 
244
  if use_cached_data:
@@ -256,6 +274,7 @@ def train(
256
  click.echo(DONE_STR)
257
 
258
  del text_data
 
259
 
260
  click.echo("Training model... ")
261
  model, accuracy = train_model(
 
111
  )
112
  @click.option(
113
  "--processes",
114
+ default=4,
115
+ help="Number of parallel jobs to run",
116
  show_default=True,
117
  )
118
  @click.option(
 
129
  verbose: bool,
130
  ) -> None:
131
  """Evaluate the model on the the specified dataset"""
132
+ import gc
133
+
134
  import joblib
135
 
136
  from app.constants import CACHE_DIR
 
157
  click.echo(DONE_STR)
158
 
159
  del text_data
160
+ gc.collect()
161
 
162
  click.echo("Loading model... ", nl=False)
163
  model = joblib.load(model_path)
164
  click.echo(DONE_STR)
165
 
166
  click.echo("Evaluating model... ", nl=False)
167
+ acc_mean, acc_std = evaluate_model(
168
+ model,
169
+ token_data,
170
+ label_data,
171
+ folds=cv,
172
+ n_jobs=processes,
173
+ verbose=verbose,
174
+ )
175
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
176
 
177
 
 
216
  type=click.IntRange(-1, None),
217
  )
218
  @click.option(
219
+ "--overwrite",
220
  is_flag=True,
221
  help="Overwrite the model file if it already exists",
222
  )
223
+ @click.option(
224
+ "--skip-cache",
225
+ is_flag=True,
226
+ help="Ignore cached tokenized data",
227
+ )
228
  @click.option(
229
  "--verbose",
230
  is_flag=True,
 
237
  batch_size: int,
238
  processes: int,
239
  seed: int,
240
+ overwrite: bool,
241
+ skip_cache: bool,
242
  verbose: bool,
243
  ) -> None:
244
  """Train the model on the provided dataset"""
245
+ import gc
246
+
247
  import joblib
248
 
249
  from app.constants import CACHE_DIR, MODELS_DIR
 
251
  from app.model import train_model
252
 
253
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
254
+ if model_path.exists() and not overwrite:
255
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
256
 
257
  cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
258
  use_cached_data = False
259
+ if cached_data_path.exists() and not skip_cache:
260
  use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
261
 
262
  if use_cached_data:
 
274
  click.echo(DONE_STR)
275
 
276
  del text_data
277
+ gc.collect()
278
 
279
  click.echo("Training model... ")
280
  model, accuracy = train_model(
app/model.py CHANGED
@@ -99,14 +99,14 @@ def train_model(
99
  cv=folds,
100
  random_state=seed,
101
  n_jobs=n_jobs,
102
- verbose=verbose,
103
  scoring="accuracy",
104
  n_iter=10,
105
  )
106
 
107
- # os.environ["PYTHONWARNINGS"] = "ignore"
108
  search.fit(text_train, label_train)
109
- # del os.environ["PYTHONWARNINGS"]
110
 
111
  best_model = search.best_estimator_
112
  return best_model, best_model.score(text_test, label_test)
@@ -117,6 +117,7 @@ def evaluate_model(
117
  token_data: list[str],
118
  label_data: list[int],
119
  folds: int = 5,
 
120
  verbose: bool = False,
121
  ) -> tuple[float, float]:
122
  """Evaluate the model using cross-validation.
@@ -126,6 +127,7 @@ def evaluate_model(
126
  token_data: Tokenized text data
127
  label_data: Label data
128
  folds: Number of cross-validation folds
 
129
  verbose: Whether to output additional information
130
 
131
  Returns:
@@ -138,8 +140,8 @@ def evaluate_model(
138
  label_data,
139
  cv=folds,
140
  scoring="accuracy",
141
- n_jobs=-1,
142
- verbose=verbose,
143
  )
144
  del os.environ["PYTHONWARNINGS"]
145
  return scores.mean(), scores.std()
 
99
  cv=folds,
100
  random_state=seed,
101
  n_jobs=n_jobs,
102
+ verbose=2 if verbose else 0,
103
  scoring="accuracy",
104
  n_iter=10,
105
  )
106
 
107
+ os.environ["PYTHONWARNINGS"] = "ignore"
108
  search.fit(text_train, label_train)
109
+ del os.environ["PYTHONWARNINGS"]
110
 
111
  best_model = search.best_estimator_
112
  return best_model, best_model.score(text_test, label_test)
 
117
  token_data: list[str],
118
  label_data: list[int],
119
  folds: int = 5,
120
+ n_jobs: int = 4,
121
  verbose: bool = False,
122
  ) -> tuple[float, float]:
123
  """Evaluate the model using cross-validation.
 
127
  token_data: Tokenized text data
128
  label_data: Label data
129
  folds: Number of cross-validation folds
130
+ n_jobs: Number of parallel jobs
131
  verbose: Whether to output additional information
132
 
133
  Returns:
 
140
  label_data,
141
  cv=folds,
142
  scoring="accuracy",
143
+ n_jobs=n_jobs,
144
+ verbose=2 if verbose else 0,
145
  )
146
  del os.environ["PYTHONWARNINGS"]
147
  return scores.mean(), scores.std()