Completely change the structure of the project
Files changed:
- .vscode/settings.json +1 -0
- README.md +4 -0
- app/__main__.py +6 -0
- app/cli.py +144 -0
- app/constants.py +27 -11
- app/gui.py +39 -73
- app/model.py +273 -113
- app/utils.py +0 -164
- deprecated/__init__.py +0 -0
- deprecated/main.py +0 -44
- deprecated/train.py +0 -152
- justfile +4 -6
- notebook.ipynb +152 -0
- poetry.lock +114 -1
- pyproject.toml +2 -1
.vscode/settings.json
CHANGED
@@ -23,5 +23,6 @@
     "**/__pycache__": true,
     "**/.ruff_cache": true,
     "**/.venv": true,
+    "**/.cache": true,
   }
 }
README.md
CHANGED
@@ -7,6 +7,10 @@ Sentiment Analysis
 3. Run `just install` to install the dependencies
 4. Run `just run --help` to see the available commands
 
+### Datasets
+- [Sentiment140](https://www.kaggle.com/datasets/kazanova/sentiment140)
+- [IMDb](https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews)
+- [Amazon Reviews](https://www.kaggle.com/datasets/bittlingmayer/amazonreviews)
 
 ### TODO
 - [ ] CLI using `click` (commands: predict, train, evaluate) with settings set via flags or environment variables
app/__main__.py
ADDED
@@ -0,0 +1,6 @@
+from __future__ import annotations
+
+from app.cli import cli_wrapper as cli
+
+if __name__ == "__main__":
+    cli()
app/cli.py
ADDED
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Literal
+
+import click
+
+__all__ = ["cli_wrapper"]
+
+ERROR_STR = click.style("ERROR", fg="red")
+DONE_STR = click.style("DONE", fg="green")
+POSITIVE_STR = click.style("POSITIVE", fg="green")
+NEUTRAL_STR = click.style("NEUTRAL", fg="yellow")
+NEGATIVE_STR = click.style("NEGATIVE", fg="red")
+
+
+@click.group()
+def cli() -> None: ...
+
+
+@cli.command()
+@click.option(
+    "--model",
+    "model_path",
+    required=True,
+    help="Path to the trained model",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
+)
+@click.option(
+    "--share/--no-share",
+    default=False,
+    help="Whether to create a shareable link",
+)
+def gui(model_path: Path, share: bool) -> None:
+    """Launch the Gradio GUI"""
+    from app.gui import launch_gui
+
+    launch_gui(model_path, share)
+
+
+@cli.command()
+@click.option(
+    "--model",
+    "model_path",
+    required=True,
+    help="Path to the trained model",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
+)
+@click.argument("text", nargs=-1)
+def predict(model_path: Path, text: list[str]) -> None:
+    """Perform sentiment analysis on the provided text.
+
+    Note: Piped input takes precedence over the text argument
+    """
+    import sys
+
+    import joblib
+
+    text = " ".join(text).strip()
+    if not sys.stdin.isatty():
+        piped_text = sys.stdin.read().strip()
+        text = piped_text or text
+
+    if not text:
+        click.echo(f"{ERROR_STR}: No text provided")
+        return
+
+    click.echo("Loading model... ", nl=False)
+    model = joblib.load(model_path)
+    click.echo(DONE_STR)
+
+    click.echo("Performing sentiment analysis... ", nl=False)
+    prediction = model.predict([text])[0]
+    if prediction == 0:
+        sentiment = NEGATIVE_STR
+    elif prediction == 1:
+        sentiment = POSITIVE_STR
+    else:
+        sentiment = NEUTRAL_STR
+    click.echo(sentiment)
+
+
+@cli.command()
+@click.option(
+    "--dataset",
+    required=True,
+    help="Dataset to train the model on",
+    type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
+)
+@click.option(
+    "--max-features",
+    default=20000,
+    help="Maximum number of features",
+    show_default=True,
+    type=click.IntRange(1, None),
+)
+@click.option(
+    "--seed",
+    default=42,
+    help="Random seed (-1 for random seed)",
+    show_default=True,
+    type=click.IntRange(-1, None),
+)
+def train(
+    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
+    max_features: int,
+    seed: int,
+) -> None:
+    """Train the model on the provided dataset"""
+    import joblib
+
+    from app.constants import MODELS_DIR
+    from app.model import create_model, load_data, train_model
+
+    model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
+    if model_path.exists():
+        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
+
+    click.echo("Preprocessing dataset... ", nl=False)
+    text_data, label_data = load_data(dataset)
+    click.echo(DONE_STR)
+
+    click.echo("Creating model... ", nl=False)
+    model = create_model(max_features, seed=None if seed == -1 else seed)
+    click.echo(DONE_STR)
+
+    click.echo("Training model... ", nl=False)
+    accuracy = train_model(model, text_data, label_data)
+    joblib.dump(model, model_path)
+    click.echo(DONE_STR)
+
+    click.echo("Model accuracy: ")
+    click.secho(f"{accuracy:.2%}", fg="blue")
+
+    # TODO: Add hyperparameter options
+    # TODO: Random/grid search for finding best classifier and hyperparameters
+
+
+def cli_wrapper() -> None:
+    cli(max_content_width=120)
+
+
+if __name__ == "__main__":
+    cli_wrapper()
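For reference, a minimal sketch of driving the new CLI from Python via click's test runner; the dataset download and the resulting model path are assumptions, not part of this commit:

from click.testing import CliRunner

from app.cli import cli

runner = CliRunner()

# Assumes data/imdb50k.csv has already been downloaded (see the README "Datasets" section)
result = runner.invoke(cli, ["train", "--dataset", "imdb50k", "--max-features", "20000"])
print(result.output)

# Hypothetical model path, following the naming scheme used by the train command
result = runner.invoke(cli, ["predict", "--model", "models/imdb50k_tfidf_ft-20000.pkl", "I loved this movie!"])
print(result.output)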
app/constants.py
CHANGED
@@ -1,16 +1,32 @@
-from pathlib import Path
-
-MODELS_DIR: Path = Path("models")
-CACHE_DIR: Path = Path("cache")
-CHECKPOINT_PATH: Path = CACHE_DIR / "pipeline.pkl"
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+CACHE_DIR = Path(os.getenv("CACHE_DIR", ".cache"))
+DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
+MODELS_DIR = Path(os.getenv("MODELS_DIR", "models"))
+
+SENTIMENT140_PATH = DATA_DIR / "sentiment140.csv"
+SENTIMENT140_URL = "https://www.kaggle.com/datasets/kazanova/sentiment140"
+
+AMAZONREVIEWS_PATH = (DATA_DIR / "amazonreviews.test.txt.bz2", DATA_DIR / "amazonreviews.train.txt.bz2")
+AMAZONREVIEWS_URL = "https://www.kaggle.com/datasets/bittlingmayer/amazonreviews"
+
+IMDB50K_PATH = DATA_DIR / "imdb50k.csv"
+IMDB50K_URL = "https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews"
+
+URL_REGEX = r"(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?\/[a-zA-Z0-9]{2,}|((https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?)|(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}(\.[a-zA-Z0-9]{2,})?"  # https://www.freecodecamp.org/news/how-to-write-a-regular-expression-for-a-url/
+EMOTICON_MAP = {
+    "SMILE": [":)", ":-)", ": )", ":D", ":-D", ": D", ";)", ";-)", "; )", ":>", ":->", ": >", ":]", ":-]", ": ]"],
+    "LOVE": ["<3", ":*", ":-*", ": *"],
+    "WINK": [";)", ";-)", "; )", ";>", ";->", "; >"],
+    "FROWN": [":(", ":-(", ": (", ":[", ":-[", ": ["],
+    "CRY": [":'(", ": (", ":' (", ":'[", ":' ["],
+    "SURPRISE": [":O", ":-O", ": O", ":0", ":-0", ": 0", ":o", ":-o", ": o"],
+    "ANGRY": [">:(", ">:-(", "> :(", ">:["],
+}
+
+CACHE_DIR.mkdir(exist_ok=True, parents=True)
+DATA_DIR.mkdir(exist_ok=True, parents=True)
+MODELS_DIR.mkdir(exist_ok=True, parents=True)
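A small sketch of how the environment-variable overrides in app/constants.py behave; the override values below are hypothetical, and they must be set before the module is first imported, since the paths are resolved (and the directories created) at import time:

import os

os.environ["DATA_DIR"] = "/tmp/sentiment-data"      # hypothetical override
os.environ["MODELS_DIR"] = "/tmp/sentiment-models"  # hypothetical override

from app.constants import DATA_DIR, IMDB50K_PATH, MODELS_DIR

print(DATA_DIR)      # /tmp/sentiment-data
print(MODELS_DIR)    # /tmp/sentiment-models
print(IMDB50K_PATH)  # /tmp/sentiment-data/imdb50k.csv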
app/gui.py
CHANGED
@@ -1,92 +1,58 @@
-from __future__ import annotations
-
-import gradio as gr
-
-TOKENIZER_EXT = ".tokenizer.pkl"
-MODEL_EXT = ".model.pkl"
-POSITIVE_LABEL = "Positive 😊"
-NEGATIVE_LABEL = "Negative 😤"
-REFRESH_SYMBOL = "🔄"
-
-
-def load_style() -> str:
-    if not CSS_PATH.is_file():
-        return ""
-
-    with Path.open(CSS_PATH) as f:
-        return f.read()
-
-
-def train_wrapper() -> None:
-    msg = "Training is not supported in the GUI."
-    raise NotImplementedError(msg)
-
-
-def evaluate_wrapper() -> None:
-    msg = "Evaluation is not supported in the GUI."
-    raise NotImplementedError(msg)
-
-
-with gr.Blocks(css=load_style()) as demo:
-    gr.Markdown("## Sentiment Analysis")
-
-    with gr.Row(elem_classes="justify-between"):
-        clear_btn = gr.ClearButton([textbox, output], value="Clear 🧹")
-        analyze_btn = gr.Button(
-            "Analyze 🔍",
-            variant="primary",
-            interactive=False,
-        )
-
-    model_selector = gr.Dropdown(
-        choices=[mdl.stem[: -len(".model")] for mdl in MODELS_DIR.glob(f"*{MODEL_EXT}")],
-        label="Model",
-        key="model-selector",
-    )
-
-    # Event handlers
-    textbox.input(
-        fn=lambda text: gr.update(interactive=bool(text.strip())),
-        inputs=[textbox],
-        outputs=[analyze_btn],
-    )
-    analyze_btn.click(
-        fn=predict_wrapper,
-        inputs=[textbox, tokenizer_selector, model_selector],
-        outputs=[output],
-    )
-
-demo.launch()
+from __future__ import annotations
+
+import os
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+import gradio as gr
+import joblib
+
+if TYPE_CHECKING:
+    from sklearn.pipeline import Pipeline
+
+__all__ = ["launch_gui"]
+
+
+POSITIVE_LABEL = "Positive 😊"
+NEUTRAL_LABEL = "Neutral 😐"
+NEGATIVE_LABEL = "Negative 😤"
+
+
+@lru_cache(maxsize=1)
+def load_model() -> Pipeline:
+    """Load the trained model and cache it."""
+    model_path = os.environ.get("MODEL_PATH", None)
+    if model_path is None:
+        msg = "MODEL_PATH environment variable not set"
+        raise ValueError(msg)
+    return joblib.load(model_path)
+
+
+def sentiment_analysis(text: str) -> str:
+    """Perform sentiment analysis on the provided text."""
+    model = load_model()
+    prediction = model.predict([text])[0]
+
+    if prediction == 0:
+        return NEGATIVE_LABEL
+    if prediction == 1:
+        return POSITIVE_LABEL
+    return NEUTRAL_LABEL
+
+
+demo = gr.Interface(
+    fn=sentiment_analysis,
+    inputs="text",
+    outputs="label",
+    title="Sentiment Analysis",
+)
+
+
+def launch_gui(model_path: str, share: bool) -> None:
+    """Launch the Gradio GUI."""
+    os.environ["MODEL_PATH"] = model_path
+    demo.launch(share=share)
+
+
+if __name__ == "__main__":
+    demo.launch()
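A minimal sketch of launching the refactored GUI directly from Python; the model path is hypothetical and must point to a pipeline produced by the train command:

from app.gui import launch_gui

# launch_gui stores the path in MODEL_PATH so the cached load_model() can pick it up,
# then blocks while Gradio serves the interface.
launch_gui("models/imdb50k_tfidf_ft-20000.pkl", share=False)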
app/model.py
CHANGED
@@ -1,144 +1,304 @@
-from typing import TYPE_CHECKING, Sequence
-
-def export_to_file(pipeline: Pipeline, path: Path) -> None:
-    joblib.dump(pipeline, path)
-
-def train_tokenizer(x: list[str], y: list[int], cache: joblib.Memory) -> Pipeline:
-    # TODO: In the future, allow for different tokenizers
-    pipeline = Pipeline(
-        [
-            (
-                "vectorize",
-                CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_TOKENIZER_FEATURES),
-            ),
-            ("tfidf", TfidfTransformer()),
-        ],
-        memory=cache,
-    )
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-    return pipeline
-
-def train_model(x: list[str], y: list[int], cache: joblib.Memory, rs: RandomState) -> Pipeline:
-    # TODO: In the future, allow for different classifiers
-    pipeline = Pipeline(
-        [
-            ("clf", LogisticRegression(max_iter=CLF_MAX_ITER, random_state=rs)),
-        ],
-        memory=cache,
-    )
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")  # Ignore joblib warnings
-        pipeline.fit(x, y)
+from __future__ import annotations
+
+import bz2
+import re
+import warnings
+from typing import Literal
+
+import pandas as pd
+from joblib import Memory
+from nltk.stem import WordNetLemmatizer
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+from sklearn.pipeline import Pipeline
+
+from app.constants import (
+    AMAZONREVIEWS_PATH,
+    AMAZONREVIEWS_URL,
+    CACHE_DIR,
+    EMOTICON_MAP,
+    IMDB50K_PATH,
+    IMDB50K_URL,
+    SENTIMENT140_PATH,
+    SENTIMENT140_URL,
+    URL_REGEX,
+)
+
+__all__ = ["load_data", "create_model", "train_model"]
+
+
+class TextCleaner(BaseEstimator, TransformerMixin):
+    def __init__(
+        self,
+        *,
+        replace_url: bool = True,
+        replace_hashtag: bool = True,
+        replace_emoticon: bool = True,
+        replace_emoji: bool = True,
+        lowercase: bool = True,
+        character_threshold: int = 2,
+        remove_special_characters: bool = True,
+        remove_extra_spaces: bool = True,
+    ):
+        self.replace_url = replace_url
+        self.replace_hashtag = replace_hashtag
+        self.replace_emoticon = replace_emoticon
+        self.replace_emoji = replace_emoji
+        self.lowercase = lowercase
+        self.character_threshold = character_threshold
+        self.remove_special_characters = remove_special_characters
+        self.remove_extra_spaces = remove_extra_spaces
+
+    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextCleaner:
+        return self
+
+    def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
+        # Replace URLs, hashtags, emoticons, and emojis
+        data = [re.sub(URL_REGEX, "URL", text) for text in data] if self.replace_url else data
+        data = [re.sub(r"#\w+", "HASHTAG", text) for text in data] if self.replace_hashtag else data
+
+        # Replace emoticons
+        if self.replace_emoticon:
+            for word, emoticons in EMOTICON_MAP.items():
+                for emoticon in emoticons:
+                    data = [text.replace(emoticon, f"EMOTE_{word}") for text in data]
+
+        # Basic text cleaning
+        data = [text.lower() for text in data] if self.lowercase else data  # Lowercase
+        threshold_pattern = re.compile(rf"\b\w{{1,{self.character_threshold}}}\b")
+        data = (
+            [re.sub(threshold_pattern, "", text) for text in data] if self.character_threshold > 0 else data
+        )  # Remove short words
+        data = (
+            [re.sub(r"[^a-zA-Z0-9\s]", "", text) for text in data] if self.remove_special_characters else data
+        )  # Remove special characters
+        data = [re.sub(r"\s+", " ", text) for text in data] if self.remove_extra_spaces else data  # Remove extra spaces
+
+        # Remove leading and trailing whitespace
+        return [text.strip() for text in data]
+
+
+class TextLemmatizer(BaseEstimator, TransformerMixin):
+    def __init__(self):
+        self.lemmatizer = WordNetLemmatizer()
+
+    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextLemmatizer:
+        return self
+
+    def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
+        return [self.lemmatizer.lemmatize(text) for text in data]
+
+
+def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
+    """Load the sentiment140 dataset and make it suitable for use.
+
+    Args:
+        include_neutral: Whether to include neutral sentiment
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    if not SENTIMENT140_PATH.exists():
+        msg = (
+            f"Sentiment140 dataset not found at: '{SENTIMENT140_PATH}'\n"
+            "Please download the dataset from:\n"
+            f"{SENTIMENT140_URL}"
+        )
+        raise FileNotFoundError(msg)
+
+    # Load the dataset
+    data = pd.read_csv(
+        SENTIMENT140_PATH,
+        encoding="ISO-8859-1",
+        names=[
+            "target",  # 0 = negative, 2 = neutral, 4 = positive
+            "id",  # The id of the tweet
+            "date",  # The date of the tweet
+            "flag",  # The query, NO_QUERY if not present
+            "user",  # The user that tweeted
+            "text",  # The text of the tweet
+        ],
+    )
+
+    # Ignore rows with neutral sentiment
+    if not include_neutral:
+        data = data[data["target"] != 2]
+
+    # Map sentiment values
+    data["sentiment"] = data["target"].map(
+        {
+            0: 0,  # Negative
+            4: 1,  # Positive
+            2: 2,  # Neutral
+        },
+    )
+
+    # Return as lists
+    return data["text"].tolist(), data["sentiment"].tolist()
+
+
+def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
+    """Load the amazonreviews dataset and make it suitable for use.
+
+    Args:
+        merge: Whether to merge the test and train datasets (otherwise ignore test)
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    test_exists = AMAZONREVIEWS_PATH[0].exists() or not merge
+    train_exists = AMAZONREVIEWS_PATH[1].exists()
+    if not (test_exists and train_exists):
+        msg = (
+            f"Amazonreviews dataset not found at: '{AMAZONREVIEWS_PATH[0]}' and '{AMAZONREVIEWS_PATH[1]}'\n"
+            "Please download the dataset from:\n"
+            f"{AMAZONREVIEWS_URL}"
+        )
+        raise FileNotFoundError(msg)
+
+    # Load the datasets
+    with bz2.BZ2File(AMAZONREVIEWS_PATH[1]) as train_file:
+        train_data = [line.decode("utf-8") for line in train_file]
+
+    test_data = []
+    if merge:
+        with bz2.BZ2File(AMAZONREVIEWS_PATH[0]) as test_file:
+            test_data = [line.decode("utf-8") for line in test_file]
+
+    # Merge the datasets
+    data = train_data + test_data
+
+    # Split the data into labels and text
+    labels, texts = zip(*(line.split(" ", 1) for line in data))
+
+    # Map sentiment values
+    sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
+
+    # Return as lists
+    return texts, sentiments
+
+
+def load_imdb50k() -> tuple[list[str], list[int]]:
+    """Load the imdb50k dataset and make it suitable for use.
+
+    Returns:
+        Text and label data
+
+    Raises:
+        FileNotFoundError: If the dataset is not found
+    """
+    # Check if the dataset exists
+    if not IMDB50K_PATH.exists():
+        msg = (
+            f"IMDB50K dataset not found at: '{IMDB50K_PATH}'\n"
+            "Please download the dataset from:\n"
+            f"{IMDB50K_URL}"
+        )  # fmt: off
+        raise FileNotFoundError(msg)
+
+    # Load the dataset
+    data = pd.read_csv(IMDB50K_PATH)
+
+    # Map sentiment values
+    data["sentiment"] = data["sentiment"].map(
+        {
+            "positive": 1,
+            "negative": 0,
+        },
+    )
+
+    # Return as lists
+    return data["review"].tolist(), data["sentiment"].tolist()
+
+
+def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
+    """Load and preprocess the specified dataset.
+
+    Args:
+        dataset: Dataset to load
+
+    Returns:
+        Text and label data
+
+    Raises:
+        ValueError: If the dataset is not recognized
+    """
+    match dataset:
+        case "sentiment140":
+            return load_sentiment140(include_neutral=False)
+        case "amazonreviews":
+            return load_amazonreviews(merge=True)
+        case "imdb50k":
+            return load_imdb50k()
+        case _:
+            msg = f"Unknown dataset: {dataset}"
+            raise ValueError(msg)
+
+
+def create_model(
+    max_features: int,
+    seed: int | None = None,
+) -> Pipeline:
+    """Create a sentiment analysis model.
+
+    Args:
+        max_features: Maximum number of features
+        seed: Random seed (None for random seed)
+
+    Returns:
+        Untrained model
+    """
+    return Pipeline(
+        [
+            # Text preprocessing
+            ("clean", TextCleaner()),
+            ("lemma", TextLemmatizer()),
+            # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
+            ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=max_features)),
+            ("tfidf", TfidfTransformer()),
+            # Classifier
+            ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
+        ],
+        memory=Memory(CACHE_DIR, verbose=0),
+    )
+
+
+def train_model(
+    model: Pipeline,
+    text_data: list[str],
+    label_data: list[int],
+    seed: int = 42,
+) -> float:
+    """Train the sentiment analysis model.
+
+    Args:
+        model: Untrained model
+        text_data: Text data
+        label_data: Label data
+        seed: Random seed (None for random seed)
+
+    Returns:
+        Accuracy score
+    """
+    text_train, text_test, label_train, label_test = train_test_split(
+        text_data,
+        label_data,
+        test_size=0.2,
+        random_state=seed,
+    )
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        model.fit(text_train, label_train)
+
+    return model.score(text_test, label_test)
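A short end-to-end sketch of the new app.model API on a tiny in-memory dataset; the toy sentences are made up, and the NLTK WordNet data used by TextLemmatizer is assumed to be downloaded first:

import nltk

from app.model import create_model, train_model

nltk.download("wordnet")  # required by TextLemmatizer (WordNetLemmatizer)

texts = [
    "I loved this movie, absolutely fantastic",
    "What a wonderful, heartwarming film",
    "Great acting and a brilliant story",
    "One of the best things I have ever watched",
    "Charming, funny and beautifully shot",
    "I hated this movie, a complete waste of time",
    "Terrible plot and awful acting",
    "Boring, predictable and far too long",
    "One of the worst films I have ever seen",
    "Dreadful pacing and a miserable script",
]
labels = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

model = create_model(max_features=100, seed=42)
accuracy = train_model(model, texts, labels)  # hold-out accuracy on a 20% split
print(f"accuracy: {accuracy:.2%}")
print(model.predict(["What a great film"]))  # e.g. [1]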
app/utils.py
DELETED
@@ -1,164 +0,0 @@
-"""Utility functions"""
-
-from __future__ import annotations
-
-import itertools
-import re
-import warnings
-from collections import deque
-from enum import Enum
-from functools import lru_cache
-from threading import Event, Lock
-from typing import Any
-
-from joblib import Memory
-from numpy.random import RandomState
-
-from constants import CACHE_DIR, DEFAULT_SEED
-
-__all__ = ["colorize", "wrap_queued_call", "get_random_state", "get_cache_memory"]
-
-
-ANSI_RESET = 0
-
-
-class Color(Enum):
-    """ANSI color codes."""
-
-    BLACK = 30
-    RED = 31
-    GREEN = 32
-    YELLOW = 33
-    BLUE = 34
-    MAGENTA = 35
-    CYAN = 36
-    WHITE = 37
-
-
-class Style(Enum):
-    """ANSI style codes."""
-
-    BOLD = 1
-    DIM = 2
-    ITALIC = 3
-    UNDERLINE = 4
-    BLINK = 5
-    INVERTED = 7
-    HIDDEN = 8
-
-
-# https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
-class FIFOLock:
-    def __init__(self):
-        self._lock = Lock()
-        self._inner_lock = Lock()
-        self._pending_threads = deque()
-
-    def acquire(self, blocking: bool = True) -> bool:
-        with self._inner_lock:
-            lock_acquired = self._lock.acquire(False)
-            if lock_acquired:
-                return True
-            if not blocking:
-                return False
-
-            release_event = Event()
-            self._pending_threads.append(release_event)
-
-        release_event.wait()
-        return self._lock.acquire()
-
-    def release(self) -> None:
-        with self._inner_lock:
-            if self._pending_threads:
-                release_event = self._pending_threads.popleft()
-                release_event.set()
-
-            self._lock.release()
-
-    __enter__ = acquire
-
-    def __exit__(self, _t, _v, _tb):  # noqa: ANN001
-        self.release()
-
-
-@lru_cache(maxsize=1)
-def get_queue_lock() -> FIFOLock:
-    return FIFOLock()
-
-
-@lru_cache(maxsize=1)
-def get_random_state(seed: int = DEFAULT_SEED) -> RandomState:
-    return RandomState(seed)
-
-
-@lru_cache(maxsize=1)
-def get_cache_memory() -> Memory:
-    return Memory(CACHE_DIR, verbose=0)
-
-
-def to_ansi(code: int) -> str:
-    """Convert an integer to an ANSI escape code."""
-    return f"\033[{code}m"
-
-
-@lru_cache(maxsize=None)
-def get_ansi_color(color: Color, bright: bool = False, background: bool = False) -> str:
-    """Get ANSI color code for the specified color, brightness and background."""
-    code = color.value
-    if bright:
-        code += 60
-    if background:
-        code += 10
-    return to_ansi(code)
-
-
-def replace_color_tag(color: Color, text: str) -> None:
-    """Replace both dark and light color tags for background and foreground."""
-    for bright, bg in itertools.product([False, True], repeat=2):
-        tag = f"{'BG_' if bg else ''}{'BRIGHT_' if bright else ''}{color.name}"
-        text = text.replace(f"[{tag}]", get_ansi_color(color, bright=bright, background=bg))
-        text = text.replace(f"[/{tag}]", to_ansi(ANSI_RESET))
-
-    return text
-
-
-@lru_cache(maxsize=256)
-def colorize(text: str, strip: bool = True) -> str:
-    """Format text with ANSI color codes using tags [COLOR], [BG_COLOR] and [STYLE].
-    Reset color/style with [/TAG].
-    Escape with double brackets [[]]. Strip leading and trailing whitespace if strip=True.
-    """
-
-    # replace foreground and background color tags
-    for color in Color:
-        text = replace_color_tag(color, text)
-
-    # replace style tags
-    for style in Style:
-        text = text.replace(f"[{style.name}]", to_ansi(style.value)).replace(f"[/{style.name}]", to_ansi(ANSI_RESET))
-
-    # if there are any tags left, remove them and throw a warning
-    pat1 = re.compile(r"((?<!\[)\[)([^\[\]]*)(\](?!\]))")
-    for match in pat1.finditer(text):
-        color = match.group(1)
-        text = text.replace(match.group(0), "")
-        warnings.warn(f"Invalid color tag: {color!r}", UserWarning, stacklevel=2)
-
-    # escape double brackets
-    pat2 = re.compile(r"\[\[[^\[\]\v]+\]\]")
-    text = pat2.sub("", text)
-
-    # reset color/style at the end
-    text += to_ansi(ANSI_RESET)
-
-    return text.strip() if strip else text
-
-
-# https://github.com/AUTOMATIC1111/stable-diffusion-webui/modules/call_queue.py
-def wrap_queued_call(func: callable) -> callable:
-    def f(*args, **kwargs) -> Any:  # noqa: ANN003, ANN002
-        with get_queue_lock():
-            return func(*args, **kwargs)
-
-    return f
deprecated/__init__.py
DELETED
File without changes
deprecated/main.py
DELETED
@@ -1,44 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-
-import click
-import joblib
-
-from app.utils import colorize
-
-
-@click.group()
-def cli() -> None: ...
-
-
-@cli.command("predict")
-@click.option(
-    "-m",
-    "--model",
-    "model_path",
-    default="models/model.pkl",
-    help="Path to the model file.",
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
-)
-@click.argument("text", nargs=-1)
-def predict(model_path: Path, text: list[str]) -> None:
-    input_text = " ".join(text).strip()
-    if not input_text:
-        click.echo("[RED]Error[/RED]: Input text is empty.")
-        return
-
-    # Load the model
-    click.echo("Loading model... ", nl=False)
-    model = joblib.load(model_path)
-    click.echo(colorize("[GREEN]DONE"))
-
-    # Run the model
-    click.echo("Performing sentiment analysis... ", nl=False)
-    prediction = model.predict([input_text])
-    sentiment = "[GREEN]POSITIVE" if prediction[0] == 1 else "[RED]NEGATIVE"
-    click.echo(colorize(sentiment))
-
-
-if __name__ == "__main__":
-    cli()
deprecated/train.py
DELETED
@@ -1,152 +0,0 @@
-from __future__ import annotations
-
-import warnings
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-import click
-import joblib
-import pandas as pd
-from numpy.random import RandomState
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score, classification_report
-from sklearn.model_selection import train_test_split
-from sklearn.pipeline import Pipeline
-
-if TYPE_CHECKING:
-    from sklearn.base import BaseEstimator
-
-SEED = 42
-DATASET_PATH = Path("data/training.1600000.processed.noemoticon.csv")
-STOPWORDS_PATH = Path("data/stopwords-en.txt")
-CHECKPOINT_PATH = Path("cache/pipeline.pkl")
-MODELS_DIR = Path("models")
-CACHE_DIR = Path("cache")
-MAX_FEATURES = 10000  # 500000
-
-# Make sure paths exist
-MODELS_DIR.mkdir(parents=True, exist_ok=True)
-CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-# Memory cache for sklearn pipelines
-mem = joblib.Memory(CACHE_DIR, verbose=0)
-
-# TODO: use xgboost
-
-
-def get_random_state(seed: int = SEED) -> RandomState:
-    return RandomState(seed)
-
-
-def load_data() -> tuple[list[str], list[int]]:
-    """The model takes in a list of strings and a list of integers where 1 is positive sentiment and 0 is negative sentiment."""
-    data = pd.read_csv(
-        DATASET_PATH,
-        encoding="ISO-8859-1",
-        names=[
-            "target",  # 0 = negative, 2 = neutral, 4 = positive
-            "id",  # The id of the tweet
-            "date",  # The date of the tweet
-            "flag",  # The query, NO_QUERY if not present
-            "user",  # The user that tweeted
-            "text",  # The text of the tweet
-        ],
-    )
-
-    # Ignore rows with neutral sentiment
-    data = data[data["target"] != 2]
-
-    # Create new column called "sentiment" with 1 for positive and 0 for negative
-    data["sentiment"] = data["target"] == 4
-
-    # Drop the columns we don't need
-    # data = data.drop(columns=["target", "id", "date", "flag", "user"])  # NOTE: No need, since we return the columns we need
-
-    # Return as lists
-    return list(data["text"]), list(data["sentiment"])
-
-
-def create_pipeline(clf: BaseEstimator) -> Pipeline:
-    return Pipeline(
-        [
-            # Preprocess
-            # ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_FEATURES)),
-            # ("tfidf", TfidfTransformer()),
-            ("vectorize", TfidfVectorizer(ngram_range=(1, 2), max_features=MAX_FEATURES)),
-            # Classifier
-            ("clf", clf),
-        ],
-        memory=mem,
-    )
-
-
-def evaluate_pipeline(pipeline: Pipeline, x: list[str], y: list[int]) -> float:
-    y_pred = pipeline.predict(x)
-    report = classification_report(y, y_pred)
-    click.echo(report)
-
-    # TODO: Confusion matrix
-
-    return accuracy_score(y, y_pred)
-
-
-def export_pipeline(pipeline: Pipeline, name: str) -> None:
-    model_path = MODELS_DIR / f"{name}.pkl"
-    joblib.dump(pipeline, model_path)
-    click.echo(f"Model exported to {model_path!r}")
-
-
-@click.command()
-@click.option("--retrain", is_flag=True, help="Train the model even if a checkpoint exists.")
-@click.option("--evaluate", is_flag=True, help="Evaluate the model.")
-@click.option("--flush-cache", is_flag=True, help="Clear sklearn cache.")
-@click.option("--seed", type=int, default=SEED, help="Random seed.")
-def train(retrain: bool, evaluate: bool, flush_cache: bool, seed: int) -> None:
-    rng = get_random_state(seed)
-
-    # Clear sklearn cache
-    if flush_cache:
-        click.echo("Clearing cache... ", nl=False)
-        mem.clear(warn=False)
-        click.echo("DONE")
-
-    # Load and split data
-    click.echo("Loading data... ", nl=False)
-    x, y = load_data()
-    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=rng)
-    click.echo("DONE")
-
-    # Train model
-    if retrain or not CHECKPOINT_PATH.exists():
-        click.echo("Training model... ", nl=False)
-        clf = LogisticRegression(max_iter=1000, random_state=rng)
-        model = create_pipeline(clf)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")  # Ignore joblib warnings
-            model.fit(x_train, y_train)
-        joblib.dump(model, CHECKPOINT_PATH)
-        click.echo("DONE")
-    else:
-        click.echo("Loading model... ", nl=False)
-        model = joblib.load(CHECKPOINT_PATH)
-        click.echo("DONE")
-
-    # Evaluate model
-    if evaluate:
-        evaluate_pipeline(model, x_test, y_test)
-
-    # Quick test
-    test_text = ["I love this movie", "I hate this movie"]
-    click.echo("Quick test:")
-    for text in test_text:
-        click.echo(f"\t{'positive' if model.predict([text])[0] else 'negative'}: {text}")
-
-    # Export model
-    click.echo("Exporting model... ", nl=False)
-    export_pipeline(model, "logistic_regression")
-    click.echo("DONE")
-
-
-if __name__ == "__main__":
-    train()
justfile
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env just --justfile
 
 @default:
+    just --list
 
 @lint:
     poetry run pre-commit run --all-files
@@ -16,8 +16,6 @@
 @requirements:
     poetry export -f requirements.txt --output requirements.txt --without dev
 
-@gui:
-    poetry run gradio app/gui.py
+[no-exit-message]
+@app *ARGS:
+    poetry run python -m app {{ARGS}}
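For illustration, the new `@app` recipe simply forwards its arguments to the package entry point, so `just app --help` is equivalent to the (hypothetical) invocation sketched below:

import subprocess

# Same command the recipe expands to; assumes the poetry environment is installed.
subprocess.run(["poetry", "run", "python", "-m", "app", "--help"], check=True)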
notebook.ipynb
ADDED
@@ -0,0 +1,152 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Sentiment Analysis"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from __future__ import annotations\n",
+    "\n",
+    "import re\n",
+    "from functools import cache\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "import seaborn as sns"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data: pd.DataFrame = None  # TODO: load dataset\n",
+    "stopwords: set[str] = None  # TODO: load stopwords"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Explore the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot the distribution\n",
+    "_, ax = plt.subplots(figsize=(6, 4))\n",
+    "data[\"sentiment\"].value_counts().plot(kind=\"bar\", ax=ax)\n",
+    "ax.set_xticklabels([\"Negative\", \"Positive\"], rotation=0)\n",
+    "ax.set_xlabel(\"Sentiment\")\n",
+    "ax.grid(False)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@cache\n",
+    "def extract_words(text: str) -> list[str]:\n",
+    "    return re.findall(r\"(\\b[^\\s]+\\b)\", text.lower())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Extract words and count them\n",
+    "words = data[\"text\"].apply(extract_words).explode()\n",
+    "word_counts = words.value_counts().reset_index()\n",
+    "word_counts.columns = [\"word\", \"count\"]\n",
+    "word_counts.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot the most common words\n",
+    "_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
+    "\n",
+    "sns.barplot(data=word_counts.head(10), x=\"count\", y=\"word\", ax=ax1)\n",
+    "ax1.set_title(\"Most common words\")\n",
+    "ax1.grid(False)\n",
+    "ax1.tick_params(axis=\"x\", rotation=45)\n",
+    "\n",
+    "ax2.set_title(\"Most common words (excluding stopwords)\")\n",
+    "sns.barplot(\n",
+    "    data=word_counts[~word_counts[\"word\"].isin(stopwords)].head(10),\n",
+    "    x=\"count\",\n",
+    "    y=\"word\",\n",
+    "    ax=ax2,\n",
+    ")\n",
+    "ax2.grid(False)\n",
+    "ax2.tick_params(axis=\"x\", rotation=45)\n",
+    "ax2.set_ylabel(\"\")\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Find best classifier"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Find best hyperparameters"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
poetry.lock
CHANGED
@@ -1479,6 +1479,31 @@ files = [
     {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"},
 ]
 
+[[package]]
+name = "nltk"
+version = "3.8.1"
+description = "Natural Language Toolkit"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"},
+    {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"},
+]
+
+[package.dependencies]
+click = "*"
+joblib = "*"
+regex = ">=2021.8.3"
+tqdm = "*"
+
+[package.extras]
+all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"]
+corenlp = ["requests"]
+machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"]
+plot = ["matplotlib"]
+tgrep = ["pyparsing"]
+twitter = ["twython"]
+
 [[package]]
 name = "nodeenv"
 version = "1.8.0"
@@ -2298,6 +2323,94 @@
 attrs = ">=22.2.0"
 rpds-py = ">=0.7.0"
 
+[[package]]
+name = "regex"
+version = "2024.5.15"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"},
+    {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"},
+    {file = "regex-2024.5.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0bd000c6e266927cb7a1bc39d55be95c4b4f65c5be53e659537537e019232b1"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eaa7ddaf517aa095fa8da0b5015c44d03da83f5bd49c87961e3c997daed0de7"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba68168daedb2c0bab7fd7e00ced5ba90aebf91024dea3c88ad5063c2a562cca"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e8d717bca3a6e2064fc3a08df5cbe366369f4b052dcd21b7416e6d71620dca1"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1337b7dbef9b2f71121cdbf1e97e40de33ff114801263b275aafd75303bd62b5"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9ebd0a36102fcad2f03696e8af4ae682793a5d30b46c647eaf280d6cfb32796"},
+    {file = "regex-2024.5.15-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9efa1a32ad3a3ea112224897cdaeb6aa00381627f567179c0314f7b65d354c62"},
+    {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1595f2d10dff3d805e054ebdc41c124753631b6a471b976963c7b28543cf13b0"},
+    {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b802512f3e1f480f41ab5f2cfc0e2f761f08a1f41092d6718868082fc0d27143"},
+    {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a0981022dccabca811e8171f913de05720590c915b033b7e601f35ce4ea7019f"},
+    {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:19068a6a79cf99a19ccefa44610491e9ca02c2be3305c7760d3831d38a467a6f"},
+    {file = "regex-2024.5.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b5269484f6126eee5e687785e83c6b60aad7663dafe842b34691157e5083e53"},
+    {file = "regex-2024.5.15-cp310-cp310-win32.whl", hash = "sha256:ada150c5adfa8fbcbf321c30c751dc67d2f12f15bd183ffe4ec7cde351d945b3"},
+    {file = "regex-2024.5.15-cp310-cp310-win_amd64.whl", hash = "sha256:ac394ff680fc46b97487941f5e6ae49a9f30ea41c6c6804832063f14b2a5a145"},
+    {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f5b1dff3ad008dccf18e652283f5e5339d70bf8ba7c98bf848ac33db10f7bc7a"},
+    {file = "regex-2024.5.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c6a2b494a76983df8e3d3feea9b9ffdd558b247e60b92f877f93a1ff43d26656"},
+    {file = "regex-2024.5.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a32b96f15c8ab2e7d27655969a23895eb799de3665fa94349f3b2fbfd547236f"},
+    {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35"},
+    {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec54d5afa89c19c6dd8541a133be51ee1017a38b412b1321ccb8d6ddbeb4cf7d"},
+    {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10e4ce0dca9ae7a66e6089bb29355d4432caed736acae36fef0fdd7879f0b0cb"},
+    {file = "regex-2024.5.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e507ff1e74373c4d3038195fdd2af30d297b4f0950eeda6f515ae3d84a1770f"},
+    {file = "regex-2024.5.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1f059a4d795e646e1c37665b9d06062c62d0e8cc3c511fe01315973a6542e40"},
+    {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649"},
+    {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:833616ddc75ad595dee848ad984d067f2f31be645d603e4d158bba656bbf516c"},
+    {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:287eb7f54fc81546346207c533ad3c2c51a8d61075127d7f6d79aaf96cdee890"},
+    {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:19dfb1c504781a136a80ecd1fff9f16dddf5bb43cec6871778c8a907a085bb3d"},
+    {file = "regex-2024.5.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:119af6e56dce35e8dfb5222573b50c89e5508d94d55713c75126b753f834de68"},
+    {file = "regex-2024.5.15-cp311-cp311-win32.whl", hash = "sha256:1c1c174d6ec38d6c8a7504087358ce9213d4332f6293a94fbf5249992ba54efa"},
+    {file = "regex-2024.5.15-cp311-cp311-win_amd64.whl", hash = "sha256:9e717956dcfd656f5055cc70996ee2cc82ac5149517fc8e1b60261b907740201"},
+    {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:632b01153e5248c134007209b5c6348a544ce96c46005d8456de1d552455b014"},
+    {file = "regex-2024.5.15-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e64198f6b856d48192bf921421fdd8ad8eb35e179086e99e99f711957ffedd6e"},
+    {file = "regex-2024.5.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68811ab14087b2f6e0fc0c2bae9ad689ea3584cad6917fc57be6a48bbd012c49"},
+    {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8ec0c2fea1e886a19c3bee0cd19d862b3aa75dcdfb42ebe8ed30708df64687a"},
+    {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0c0c0003c10f54a591d220997dd27d953cd9ccc1a7294b40a4be5312be8797b"},
+    {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2431b9e263af1953c55abbd3e2efca67ca80a3de8a0437cb58e2421f8184717a"},
+    {file = "regex-2024.5.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a605586358893b483976cffc1723fb0f83e526e8f14c6e6614e75919d9862cf"},
+    {file = "regex-2024.5.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391d7f7f1e409d192dba8bcd42d3e4cf9e598f3979cdaed6ab11288da88cb9f2"},
+    {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9ff11639a8d98969c863d4617595eb5425fd12f7c5ef6621a4b74b71ed8726d5"},
+    {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4eee78a04e6c67e8391edd4dad3279828dd66ac4b79570ec998e2155d2e59fd5"},
+    {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8fe45aa3f4aa57faabbc9cb46a93363edd6197cbc43523daea044e9ff2fea83e"},
+    {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:d0a3d8d6acf0c78a1fff0e210d224b821081330b8524e3e2bc5a68ef6ab5803d"},
+    {file = "regex-2024.5.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c486b4106066d502495b3025a0a7251bf37ea9540433940a23419461ab9f2a80"},
+    {file = "regex-2024.5.15-cp312-cp312-win32.whl", hash = "sha256:c49e15eac7c149f3670b3e27f1f28a2c1ddeccd3a2812cba953e01be2ab9b5fe"},
+    {file = "regex-2024.5.15-cp312-cp312-win_amd64.whl", hash = "sha256:673b5a6da4557b975c6c90198588181029c60793835ce02f497ea817ff647cb2"},
+    {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:87e2a9c29e672fc65523fb47a90d429b70ef72b901b4e4b1bd42387caf0d6835"},
+    {file = "regex-2024.5.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c3bea0ba8b73b71b37ac833a7f3fd53825924165da6a924aec78c13032f20850"},
+    {file = "regex-2024.5.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfc4f82cabe54f1e7f206fd3d30fda143f84a63fe7d64a81558d6e5f2e5aaba9"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5bb9425fe881d578aeca0b2b4b3d314ec88738706f66f219c194d67179337cb"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64c65783e96e563103d641760664125e91bd85d8e49566ee560ded4da0d3e704"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf2430df4148b08fb4324b848672514b1385ae3807651f3567871f130a728cc3"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5397de3219a8b08ae9540c48f602996aa6b0b65d5a61683e233af8605c42b0f2"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:455705d34b4154a80ead722f4f185b04c4237e8e8e33f265cd0798d0e44825fa"},
+    {file = "regex-2024.5.15-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2b6f1b3bb6f640c1a92be3bbfbcb18657b125b99ecf141fb3310b5282c7d4ed"},
+    {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3ad070b823ca5890cab606c940522d05d3d22395d432f4aaaf9d5b1653e47ced"},
+    {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5b5467acbfc153847d5adb21e21e29847bcb5870e65c94c9206d20eb4e99a384"},
+    {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e6662686aeb633ad65be2a42b4cb00178b3fbf7b91878f9446075c404ada552f"},
+    {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:2b4c884767504c0e2401babe8b5b7aea9148680d2e157fa28f01529d1f7fcf67"},
+    {file = "regex-2024.5.15-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3cd7874d57f13bf70078f1ff02b8b0aa48d5b9ed25fc48547516c6aba36f5741"},
+    {file = "regex-2024.5.15-cp38-cp38-win32.whl", hash = "sha256:e4682f5ba31f475d58884045c1a97a860a007d44938c4c0895f41d64481edbc9"},
+    {file = "regex-2024.5.15-cp38-cp38-win_amd64.whl", hash = "sha256:d99ceffa25ac45d150e30bd9ed14ec6039f2aad0ffa6bb87a5936f5782fc1569"},
+ 
+ [[package]]
+ name = "requests"
+ version = "2.31.0"
+ 
+@@ -3174,4 +3287,4 @@
+ [metadata]
+ lock-version = "2.0"
+ python-versions = "^3.12"
+-content-hash = "
|
2395 |
+
{file = "regex-2024.5.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13cdaf31bed30a1e1c2453ef6015aa0983e1366fad2667657dbcac7b02f67133"},
|
2396 |
+
{file = "regex-2024.5.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cac27dcaa821ca271855a32188aa61d12decb6fe45ffe3e722401fe61e323cd1"},
|
2397 |
+
{file = "regex-2024.5.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7dbe2467273b875ea2de38ded4eba86cbcbc9a1a6d0aa11dcf7bd2e67859c435"},
|
2398 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64f18a9a3513a99c4bef0e3efd4c4a5b11228b48aa80743be822b71e132ae4f5"},
|
2399 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d347a741ea871c2e278fde6c48f85136c96b8659b632fb57a7d1ce1872547600"},
|
2400 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1878b8301ed011704aea4c806a3cadbd76f84dece1ec09cc9e4dc934cfa5d4da"},
|
2401 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4babf07ad476aaf7830d77000874d7611704a7fcf68c9c2ad151f5d94ae4bfc4"},
|
2402 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35cb514e137cb3488bce23352af3e12fb0dbedd1ee6e60da053c69fb1b29cc6c"},
|
2403 |
+
{file = "regex-2024.5.15-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cdd09d47c0b2efee9378679f8510ee6955d329424c659ab3c5e3a6edea696294"},
|
2404 |
+
{file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:72d7a99cd6b8f958e85fc6ca5b37c4303294954eac1376535b03c2a43eb72629"},
|
2405 |
+
{file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a094801d379ab20c2135529948cb84d417a2169b9bdceda2a36f5f10977ebc16"},
|
2406 |
+
{file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c0c18345010870e58238790a6779a1219b4d97bd2e77e1140e8ee5d14df071aa"},
|
2407 |
+
{file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:16093f563098448ff6b1fa68170e4acbef94e6b6a4e25e10eae8598bb1694b5d"},
|
2408 |
+
{file = "regex-2024.5.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e38a7d4e8f633a33b4c7350fbd8bad3b70bf81439ac67ac38916c4a86b465456"},
|
2409 |
+
{file = "regex-2024.5.15-cp39-cp39-win32.whl", hash = "sha256:71a455a3c584a88f654b64feccc1e25876066c4f5ef26cd6dd711308aa538694"},
|
2410 |
+
{file = "regex-2024.5.15-cp39-cp39-win_amd64.whl", hash = "sha256:cab12877a9bdafde5500206d1020a584355a97884dfd388af3699e9137bf7388"},
|
2411 |
+
{file = "regex-2024.5.15.tar.gz", hash = "sha256:d3ee02d9e5f482cc8309134a91eeaacbdd2261ba111b0fef3748eeb4913e6a2c"},
|
2412 |
+
]
|
2413 |
+
|
2414 |
[[package]]
|
2415 |
name = "requests"
|
2416 |
version = "2.31.0"
|
|
|
3287 |
[metadata]
|
3288 |
lock-version = "2.0"
|
3289 |
python-versions = "^3.12"
|
3290 |
+
content-hash = "988f4561272067771efc60acdb2687f0586be48c1bf401452696c51e8f69b534"
|
pyproject.toml
CHANGED
@@ -1,13 +1,14 @@
 [tool.poetry]
 name = "sentiment-analysis"
 package-mode = false
-packages = [{ include = "app" }]
 
 [tool.poetry.dependencies]
 python = "^3.12"
 click = "^8.1.7"
 scikit-learn = "^1.4.2"
 gradio = "^4.31.0"
+colorama = "^0.4.6"
+nltk = "^3.8.1"
 
 [tool.poetry.group.train.dependencies]
 pandas = "^2.2.2"
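
For context on the two runtime dependencies added above, here is a minimal usage sketch, assuming colorama is used for colored terminal output and nltk for text preprocessing. The actual preprocessing in app/model.py is not shown in this diff, so the corpus names (punkt, stopwords) and the steps below are illustrative assumptions, not the project's code.

# Hypothetical sketch of how the new dependencies could be used; not taken from app/model.py.
import nltk
from colorama import Fore, Style, init as colorama_init
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# NLTK ships its corpora separately from the pip package; fetch them once before first use.
nltk.download("punkt", quiet=True)
nltk.download("stopwords", quiet=True)

# colorama makes ANSI color codes work on Windows terminals as well.
colorama_init(autoreset=True)


def preprocess(text: str) -> list[str]:
    """Lowercase, tokenize, and drop non-alphabetic tokens and English stopwords."""
    stop_words = set(stopwords.words("english"))
    tokens = word_tokenize(text.lower())
    return [token for token in tokens if token.isalpha() and token not in stop_words]


print(Fore.GREEN + "tokens:" + Style.RESET_ALL, preprocess("This movie was surprisingly good!"))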