# Tymec's picture
# Completely change the structure of the project
# 85ac990
# raw
# history blame
# 3.75 kB
from __future__ import annotations
from pathlib import Path
from typing import Literal
import click
# Only the entry-point wrapper is exported from this module.
__all__ = ["cli_wrapper"]

# Pre-styled status labels echoed by the commands below.
ERROR_STR = click.style("ERROR", fg="red")
DONE_STR = click.style("DONE", fg="green")

# Pre-styled sentiment labels printed as the result of `predict`.
POSITIVE_STR = click.style("POSITIVE", fg="green")
NEUTRAL_STR = click.style("NEUTRAL", fg="yellow")
NEGATIVE_STR = click.style("NEGATIVE", fg="red")
@click.group()
def cli() -> None:
    # Root command group: the subcommands below register themselves on it via
    # `@cli.command()`. The body is intentionally empty.
    # NOTE: adding a docstring here would change the `--help` output, because
    # click uses a command's docstring as its help text.
    pass
@cli.command()
@click.option(
    "--model",
    "model_path",
    required=True,
    help="Path to the trained model",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        resolve_path=True,
        path_type=Path,
    ),
)
@click.option(
    "--share/--no-share",
    default=False,
    help="Whether to create a shareable link",
)
def gui(model_path: Path, share: bool) -> None:
    """Launch the Gradio GUI"""
    # The docstring above doubles as the command's help text, so it is kept
    # verbatim.
    #
    # model_path: existing, readable model file (validated and resolved by
    #             click before this function runs).
    # share:      value of the --share/--no-share flag, forwarded unchanged.
    #
    # Imported inside the command so the GUI module is only loaded when this
    # command actually runs.
    from app.gui import launch_gui as start_gui

    start_gui(model_path, share)
@cli.command()
@click.option(
    "--model",
    "model_path",
    required=True,
    help="Path to the trained model",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
)
@click.argument("text", nargs=-1)
def predict(model_path: Path, text: tuple[str, ...]) -> None:
    """Perform sentiment analysis on the provided text.

    Note: Piped input takes precedence over the text argument
    """
    # The docstring above doubles as the command's help text; implementation
    # notes therefore live in comments.
    #
    # model_path: existing, readable model file (validated by click).
    # text:       words given on the command line; click passes a variadic
    #             argument (nargs=-1) as a tuple.
    #
    # Exits with status 1 (message on stderr) when no text is available from
    # either the arguments or stdin.
    import sys

    import joblib

    input_text = " ".join(text).strip()

    # When stdin is not a terminal it is read in full; non-empty piped content
    # wins over the command-line words.
    if not sys.stdin.isatty():
        piped_text = sys.stdin.read().strip()
        input_text = piped_text or input_text

    if not input_text:
        # Report failures on stderr and with a non-zero status so shell
        # pipelines and scripts can detect them.
        click.echo(f"{ERROR_STR}: No text provided", err=True)
        sys.exit(1)

    click.echo("Loading model... ", nl=False)
    # NOTE(security): joblib.load is pickle-based and can execute arbitrary
    # code while loading; only point --model at files from a trusted source.
    model = joblib.load(model_path)
    click.echo(DONE_STR)

    click.echo("Performing sentiment analysis... ", nl=False)
    # The model is called with a single-item batch; take its only result.
    prediction = model.predict([input_text])[0]

    # 0 is reported as negative, 1 as positive, anything else as neutral.
    if prediction == 0:
        sentiment = NEGATIVE_STR
    elif prediction == 1:
        sentiment = POSITIVE_STR
    else:
        sentiment = NEUTRAL_STR
    click.echo(sentiment)
@cli.command()
@click.option(
    "--dataset",
    required=True,
    help="Dataset to train the model on",
    type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
)
@click.option(
    "--max-features",
    default=20000,
    help="Maximum number of features",
    show_default=True,
    type=click.IntRange(1, None),
)
@click.option(
    "--seed",
    default=42,
    help="Random seed (-1 for random seed)",
    show_default=True,
    type=click.IntRange(-1, None),
)
def train(
    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
    max_features: int,
    seed: int,
) -> None:
    """Train the model on the provided dataset"""
    # The docstring above doubles as the command's help text; implementation
    # notes therefore live in comments.
    #
    # dataset:      one of the names accepted by the --dataset choice.
    # max_features: positive integer forwarded to create_model.
    # seed:         -1 is translated to None before being passed on as the
    #               model seed; any other value is passed through unchanged.
    #
    # Aborts (via click.confirm(abort=True)) if the target model file already
    # exists and the user declines to overwrite it.
    import joblib

    from app.constants import MODELS_DIR
    from app.model import create_model, load_data, train_model

    model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
    if model_path.exists():
        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)

    # Make sure the output directory exists before any expensive work, so a
    # missing directory cannot make joblib.dump fail after training finished.
    # Assumes MODELS_DIR is a pathlib.Path, as its use with `/` and `.exists()`
    # above suggests.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    click.echo("Preprocessing dataset... ", nl=False)
    text_data, label_data = load_data(dataset)
    click.echo(DONE_STR)

    click.echo("Creating model... ", nl=False)
    model = create_model(max_features, seed=None if seed == -1 else seed)
    click.echo(DONE_STR)

    click.echo("Training model... ", nl=False)
    accuracy = train_model(model, text_data, label_data)
    joblib.dump(model, model_path)
    click.echo(DONE_STR)

    # Keep the value on the same line as its label, like the other progress
    # messages in this module.
    click.echo("Model accuracy: ", nl=False)
    click.secho(f"{accuracy:.2%}", fg="blue")

    # TODO: Add hyperparameter options
    # TODO: Random/grid search for finding best classifier and hyperparameters
def cli_wrapper() -> None:
    """Run the command-line interface, allowing help output up to 120 columns wide."""
    help_width = 120
    cli(max_content_width=help_width)
# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli_wrapper()