# Command-line interface for the sentiment-analysis app.
from __future__ import annotations | |
from pathlib import Path | |
from typing import Literal | |
import click | |
# Public API of this module: only the console-script entry point is exported.
__all__ = ["cli_wrapper"]

# Pre-styled success marker reused by all "... DONE" progress messages below.
DONE_STR = click.style("DONE", fg="green")
def cli() -> None: ... | |
def gui(model_path: Path, share: bool) -> None:
    """Launch the Gradio GUI"""
    import os

    from app.gui import launch_gui

    # launch_gui discovers the model through the environment rather than a
    # parameter, so publish the path before starting the interface.
    posix_path = model_path.as_posix()
    os.environ["MODEL_PATH"] = posix_path
    launch_gui(share)
def predict(model_path: Path, text: list[str]) -> None:
    """Perform sentiment analysis on the provided text.

    Note: Piped input takes precedence over the text argument
    """
    import sys

    import joblib

    from app.model import infer_model

    # Join the CLI tokens into one query; do not rebind the list[str]
    # parameter to a str (the original shadowed `text` with a new type).
    query = " ".join(text).strip()
    if not sys.stdin.isatty():
        # Non-empty piped stdin wins over the positional argument.
        piped_text = sys.stdin.read().strip()
        query = piped_text or query

    if not query:
        msg = "No text provided"
        raise click.UsageError(msg)

    click.echo("Loading model... ", nl=False)
    model = joblib.load(model_path)
    click.echo(DONE_STR)

    click.echo("Performing sentiment analysis... ", nl=False)
    prediction = infer_model(model, [query])[0]

    # Class id -> styled label; any id other than 0/1 is reported as neutral.
    if prediction == 0:
        sentiment = click.style("NEGATIVE", fg="red")
    elif prediction == 1:
        sentiment = click.style("POSITIVE", fg="green")
    else:
        sentiment = click.style("NEUTRAL", fg="yellow")
    click.echo(sentiment)
def evaluate(
    dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
    model_path: Path,
    cv: int,
    batch_size: int,
    processes: int,
    verbose: bool,
) -> None:
    """Evaluate the model on the specified dataset.

    Tokenized data is cached per-dataset under CACHE_DIR; when a cache file
    exists the user is asked whether to reuse it instead of re-tokenizing.
    """
    import joblib

    from app.constants import CACHE_DIR
    from app.data import load_data, tokenize
    from app.model import evaluate_model

    cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
    use_cached_data = False
    if cached_data_path.exists():
        use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)

    if use_cached_data:
        click.echo("Loading cached data... ", nl=False)
        token_data, label_data = joblib.load(cached_data_path)
        click.echo(DONE_STR)
    else:
        click.echo("Loading dataset... ", nl=False)
        text_data, label_data = load_data(dataset)
        click.echo(DONE_STR)

        click.echo("Tokenizing data... ", nl=False)
        token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
        joblib.dump((token_data, label_data), cached_data_path, compress=3)
        click.echo(DONE_STR)

        # Free the raw text as soon as the tokens exist; datasets can be large.
        # (Must stay inside this branch: text_data is unbound on the cached path.)
        del text_data

    click.echo("Loading model... ", nl=False)
    model = joblib.load(model_path)
    click.echo(DONE_STR)

    click.echo("Evaluating model... ", nl=False)
    acc_mean, acc_std = evaluate_model(model, token_data, label_data, folds=cv, verbose=verbose)
    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
def train(
    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
    max_features: int,
    cv: int,
    batch_size: int,
    processes: int,
    seed: int,
    force: bool,
    verbose: bool,
) -> None:
    """Train the model on the provided dataset"""
    import joblib

    from app.constants import CACHE_DIR, MODELS_DIR
    from app.data import load_data, tokenize
    from app.model import create_model, train_model

    model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
    if model_path.exists() and not force:
        # abort=True exits the command if the user declines to overwrite.
        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)

    # Reuse previously tokenized data when present and the user agrees.
    cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
    use_cached_data = cached_data_path.exists() and click.confirm(
        f"Found existing tokenized data for '{dataset}'. Use it?", default=True
    )

    if use_cached_data:
        click.echo("Loading cached data... ", nl=False)
        token_data, label_data = joblib.load(cached_data_path)
        click.echo(DONE_STR)
    else:
        click.echo("Loading dataset... ", nl=False)
        text_data, label_data = load_data(dataset)
        click.echo(DONE_STR)

        click.echo("Tokenizing data... ", nl=False)
        token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
        joblib.dump((token_data, label_data), cached_data_path, compress=3)
        click.echo(DONE_STR)

        del text_data

    click.echo("Training model... ")
    # A seed of -1 means "unseeded" for model creation.
    # NOTE(review): train_model still receives the raw seed (possibly -1) --
    # confirm whether it applies the same -1 -> None convention.
    rng_seed = None if seed == -1 else seed
    base_model = create_model(max_features, seed=rng_seed, verbose=verbose)
    trained_model, accuracy = train_model(base_model, token_data, label_data, folds=cv, seed=seed, verbose=verbose)

    click.echo("Model accuracy: ", nl=False)
    click.secho(f"{accuracy:.2%}", fg="blue")

    click.echo("Model saved to: ", nl=False)
    joblib.dump(trained_model, model_path, compress=3)
    click.secho(str(model_path), fg="blue")
def cli_wrapper() -> None:
    # Console-script entry point: run the CLI with wider help output (120 cols).
    # NOTE(review): cli() as visible here takes no arguments; this call relies
    # on click forwarding extra kwargs to the Context -- confirm the decorator.
    cli(max_content_width=120)
# Script entry point: delegate to the wrapper so direct execution and the
# installed console script share one code path.
if __name__ == "__main__":
    cli_wrapper()