Spaces:
Running
Running
from __future__ import annotations | |
from pathlib import Path | |
from typing import Literal | |
import click | |
__all__ = ["cli_wrapper"] | |
ERROR_STR = click.style("ERROR", fg="red") | |
DONE_STR = click.style("DONE", fg="green") | |
POSITIVE_STR = click.style("POSITIVE", fg="green") | |
NEUTRAL_STR = click.style("NEUTRAL", fg="yellow") | |
NEGATIVE_STR = click.style("NEGATIVE", fg="red") | |
def cli() -> None: ... | |
def gui(model_path: Path, share: bool) -> None: | |
"""Launch the Gradio GUI""" | |
from app.gui import launch_gui | |
launch_gui(model_path, share) | |
def predict(model_path: Path, text: list[str]) -> None: | |
"""Perform sentiment analysis on the provided text. | |
Note: Piped input takes precedence over the text argument | |
""" | |
import sys | |
import joblib | |
text = " ".join(text).strip() | |
if not sys.stdin.isatty(): | |
piped_text = sys.stdin.read().strip() | |
text = piped_text or text | |
if not text: | |
click.echo(f"{ERROR_STR}: No text provided") | |
return | |
click.echo("Loading model... ", nl=False) | |
model = joblib.load(model_path) | |
click.echo(DONE_STR) | |
click.echo("Performing sentiment analysis... ", nl=False) | |
prediction = model.predict([text])[0] | |
if prediction == 0: | |
sentiment = NEGATIVE_STR | |
elif prediction == 1: | |
sentiment = POSITIVE_STR | |
else: | |
sentiment = NEUTRAL_STR | |
click.echo(sentiment) | |
def train( | |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"], | |
max_features: int, | |
seed: int, | |
) -> None: | |
"""Train the model on the provided dataset""" | |
import joblib | |
from app.constants import MODELS_DIR | |
from app.model import create_model, load_data, train_model | |
model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl" | |
if model_path.exists(): | |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True) | |
click.echo("Preprocessing dataset... ", nl=False) | |
text_data, label_data = load_data(dataset) | |
click.echo(DONE_STR) | |
click.echo("Creating model... ", nl=False) | |
model = create_model(max_features, seed=None if seed == -1 else seed) | |
click.echo(DONE_STR) | |
click.echo("Training model... ", nl=False) | |
accuracy = train_model(model, text_data, label_data) | |
joblib.dump(model, model_path) | |
click.echo(DONE_STR) | |
click.echo("Model accuracy: ") | |
click.secho(f"{accuracy:.2%}", fg="blue") | |
# TODO: Add hyperparameter options | |
# TODO: Random/grid search for finding best classifier and hyperparameters | |
def cli_wrapper() -> None: | |
cli(max_content_width=120) | |
if __name__ == "__main__": | |
cli_wrapper() | |