Spaces:
Running
Running
Add amazonreviews model
Browse files- README.md +1 -2
- app/cli.py +3 -4
- app/model.py +25 -59
- models/amazonreviews_tfidf_ft20000.pkl +3 -0
README.md
CHANGED
@@ -213,8 +213,7 @@ The following pre-trained models are available for use:
|
|
213 |
| --- | --- | --- | --- | --- | --- | --- |
|
214 |
| `imdb50k` | `tfidf` | `LinearRegression` | 20 000 | 83.24% ± 0.99% | 89.24% ± 0.13% | [Here](models/imdb50k_tfidf_ft20000.pkl) |
|
215 |
| `sentiment140` | `tfidf` | `LinearRegression` | 20 000 | 83.24% ± 0.99% | 77.32% ± 0.28% | [Here](models/sentiment140_tfidf_ft20000.pkl) |
|
216 |
-
| `amazonreviews` | `tfidf` | `LinearRegression` | 20 000 |
|
217 |
-
|
218 |
|
219 |
## License
|
220 |
Distributed under the MIT License. See [LICENSE](LICENSE) for more information.
|
|
|
213 |
| --- | --- | --- | --- | --- | --- | --- |
|
214 |
| `imdb50k` | `tfidf` | `LinearRegression` | 20 000 | 83.24% ± 0.99% | 89.24% ± 0.13% | [Here](models/imdb50k_tfidf_ft20000.pkl) |
|
215 |
| `sentiment140` | `tfidf` | `LinearRegression` | 20 000 | 83.24% ± 0.99% | 77.32% ± 0.28% | [Here](models/sentiment140_tfidf_ft20000.pkl) |
|
216 |
+
| `amazonreviews` | `tfidf` | `LinearRegression` | 20 000 | 82.17% ± 0.85% | ❌ | [Here](models/amazonreviews_tfidf_ft20000.pkl) |
|
|
|
217 |
|
218 |
## License
|
219 |
Distributed under the MIT License. See [LICENSE](LICENSE) for more information.
|
app/cli.py
CHANGED
@@ -217,10 +217,9 @@ def evaluate(
|
|
217 |
)
|
218 |
@click.option(
|
219 |
"--min-df",
|
220 |
-
default=
|
221 |
-
help="Minimum document frequency for the
|
222 |
show_default=True,
|
223 |
-
type=click.FloatRange(0, 1),
|
224 |
)
|
225 |
@click.option(
|
226 |
"--cv",
|
@@ -268,7 +267,7 @@ def train(
|
|
268 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
269 |
vectorizer: Literal["tfidf", "count", "hashing"],
|
270 |
max_features: int,
|
271 |
-
min_df:
|
272 |
cv: int,
|
273 |
token_batch_size: int,
|
274 |
token_jobs: int,
|
|
|
217 |
)
|
218 |
@click.option(
|
219 |
"--min-df",
|
220 |
+
default=5,
|
221 |
+
help="Minimum document frequency for the features (ignored for hashing)",
|
222 |
show_default=True,
|
|
|
223 |
)
|
224 |
@click.option(
|
225 |
"--cv",
|
|
|
267 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
268 |
vectorizer: Literal["tfidf", "count", "hashing"],
|
269 |
max_features: int,
|
270 |
+
min_df: int,
|
271 |
cv: int,
|
272 |
token_batch_size: int,
|
273 |
token_jobs: int,
|
app/model.py
CHANGED
@@ -10,7 +10,6 @@ from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer,
|
|
10 |
from sklearn.linear_model import LogisticRegression
|
11 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
12 |
from sklearn.pipeline import Pipeline
|
13 |
-
from tqdm import tqdm
|
14 |
|
15 |
from app.constants import CACHE_DIR
|
16 |
from app.data import tokenize
|
@@ -36,7 +35,7 @@ def _identity(x: list[str]) -> list[str]:
|
|
36 |
def _get_vectorizer(
|
37 |
name: Literal["tfidf", "count", "hashing"],
|
38 |
n_features: int,
|
39 |
-
min_df:
|
40 |
ngram: tuple[int, int] = (1, 2),
|
41 |
) -> TransformerMixin:
|
42 |
"""Get the appropriate vectorizer.
|
@@ -96,7 +95,7 @@ def train_model(
|
|
96 |
label_data: list[int],
|
97 |
vectorizer: Literal["tfidf", "count", "hashing"],
|
98 |
max_features: int,
|
99 |
-
min_df:
|
100 |
folds: int = 5,
|
101 |
n_jobs: int = 4,
|
102 |
seed: int = 42,
|
@@ -129,66 +128,33 @@ def train_model(
|
|
129 |
)
|
130 |
|
131 |
vectorizer = _get_vectorizer(vectorizer, max_features, min_df)
|
132 |
-
|
133 |
-
|
134 |
-
# (LinearSVC(max_iter=10000, random_state=rs), {"C": np.logspace(-4, 4, 20)}),
|
135 |
-
# (KNeighborsClassifier(), {"n_neighbors": np.arange(1, 10)}),
|
136 |
-
# (RandomForestClassifier(random_state=rs), {"n_estimators": np.arange(50, 500, 50)}),
|
137 |
-
# (
|
138 |
-
# VotingClassifier(
|
139 |
-
# estimators=[
|
140 |
-
# ("lr", LogisticRegression(max_iter=1000, random_state=rs)),
|
141 |
-
# ("knn", KNeighborsClassifier()),
|
142 |
-
# ("rf", RandomForestClassifier(random_state=rs)),
|
143 |
-
# ],
|
144 |
-
# ),
|
145 |
-
# {
|
146 |
-
# "lr__C": np.logspace(-4, 4, 20),
|
147 |
-
# "knn__n_neighbors": np.arange(1, 10),
|
148 |
-
# "rf__n_estimators": np.arange(50, 500, 50),
|
149 |
-
# },
|
150 |
-
# ),
|
151 |
-
]
|
152 |
-
|
153 |
-
models = []
|
154 |
-
for clf, param_dist in (pbar := tqdm(classifiers, unit="clf")):
|
155 |
-
param_dist = {f"classifier__{k}": v for k, v in param_dist.items()}
|
156 |
-
|
157 |
-
model = Pipeline(
|
158 |
-
[("vectorizer", vectorizer), ("classifier", clf)],
|
159 |
-
memory=Memory(CACHE_DIR, verbose=0),
|
160 |
-
)
|
161 |
-
|
162 |
-
search = RandomizedSearchCV(
|
163 |
-
model,
|
164 |
-
param_dist,
|
165 |
-
cv=folds,
|
166 |
-
random_state=rs,
|
167 |
-
n_jobs=n_jobs,
|
168 |
-
# verbose=2,
|
169 |
-
scoring="accuracy",
|
170 |
-
n_iter=10,
|
171 |
-
)
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
|
178 |
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
|
185 |
-
|
186 |
-
print("--------------")
|
187 |
-
print("\n".join(f"{model.named_steps['classifier'].__class__.__name__}: {acc:.2%}" for model, acc in models))
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
return best_model, best_acc
|
192 |
|
193 |
|
194 |
def evaluate_model(
|
@@ -211,7 +177,7 @@ def evaluate_model(
|
|
211 |
Mean accuracy and standard deviation
|
212 |
"""
|
213 |
with warnings.catch_warnings():
|
214 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
215 |
scores = cross_val_score(
|
216 |
model,
|
217 |
token_data,
|
|
|
10 |
from sklearn.linear_model import LogisticRegression
|
11 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
12 |
from sklearn.pipeline import Pipeline
|
|
|
13 |
|
14 |
from app.constants import CACHE_DIR
|
15 |
from app.data import tokenize
|
|
|
35 |
def _get_vectorizer(
|
36 |
name: Literal["tfidf", "count", "hashing"],
|
37 |
n_features: int,
|
38 |
+
min_df: int = 5,
|
39 |
ngram: tuple[int, int] = (1, 2),
|
40 |
) -> TransformerMixin:
|
41 |
"""Get the appropriate vectorizer.
|
|
|
95 |
label_data: list[int],
|
96 |
vectorizer: Literal["tfidf", "count", "hashing"],
|
97 |
max_features: int,
|
98 |
+
min_df: int = 5,
|
99 |
folds: int = 5,
|
100 |
n_jobs: int = 4,
|
101 |
seed: int = 42,
|
|
|
128 |
)
|
129 |
|
130 |
vectorizer = _get_vectorizer(vectorizer, max_features, min_df)
|
131 |
+
classifier = LogisticRegression(max_iter=1000, random_state=rs)
|
132 |
+
param_dist = {"classifier__C": np.logspace(-4, 4, 20)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
model = Pipeline(
|
135 |
+
[("vectorizer", vectorizer), ("classifier", classifier)],
|
136 |
+
memory=Memory(CACHE_DIR, verbose=0),
|
137 |
+
)
|
|
|
138 |
|
139 |
+
search = RandomizedSearchCV(
|
140 |
+
model,
|
141 |
+
param_dist,
|
142 |
+
cv=folds,
|
143 |
+
random_state=rs,
|
144 |
+
n_jobs=n_jobs,
|
145 |
+
verbose=2,
|
146 |
+
scoring="accuracy",
|
147 |
+
n_iter=10,
|
148 |
+
)
|
149 |
|
150 |
+
with warnings.catch_warnings():
|
151 |
+
warnings.filterwarnings("once", category=ConvergenceWarning)
|
152 |
+
warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
|
153 |
|
154 |
+
search.fit(text_train, label_train)
|
|
|
|
|
155 |
|
156 |
+
final_model = search.best_estimator_
|
157 |
+
return final_model, final_model.score(text_test, label_test)
|
|
|
158 |
|
159 |
|
160 |
def evaluate_model(
|
|
|
177 |
Mean accuracy and standard deviation
|
178 |
"""
|
179 |
with warnings.catch_warnings():
|
180 |
+
warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
|
181 |
scores = cross_val_score(
|
182 |
model,
|
183 |
token_data,
|
models/amazonreviews_tfidf_ft20000.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ccc3156426d2086e10c241de5f186c756de550858ee2964471c26d0e24b8996
|
3 |
+
size 442646
|