Spaces:
Runtime error
Runtime error
Add app
Browse files- app.py +103 -0
- requirements.txt +22 -0
app.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import hopsworks
|
3 |
+
import pandas as pd
|
4 |
+
import streamlit as st
|
5 |
+
import seaborn as sns
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
#from dotenv import load_dotenv
|
9 |
+
#load_dotenv()
|
10 |
+
|
11 |
+
@st.experimental_memo
def load_data():
    """Log in to Hopsworks and load the dashboard's data and model artifacts.

    Returns:
        df: DataFrame of joined post/user/subreddit snapshot features.
        img_predictions: prediction-error plot (linear scale) as an image array.
        img_predictions_logscale: prediction-error plot (log scale) as an image array.
        df_metrics: DataFrame of test-set metrics, one column per target
            ("likes", "upvote_ratio"), indexed by metric name.
    """
    project = hopsworks.login()
    fs = project.get_feature_store()

    # Versions are configurable via env vars. os.getenv returns a *string*
    # when the variable is set, so cast to int before passing it to the
    # Hopsworks API (previously a set env var would pass a str version).
    posts_fg = fs.get_feature_group(
        "reddit_posts", version=int(os.getenv("POSTS_FG_VERSION", 1)))
    users_fg = fs.get_feature_group(
        "reddit_users", version=int(os.getenv("USERS_FG_VERSION", 1)))
    subreddits_fg = fs.get_feature_group(
        "reddit_subreddits", version=int(os.getenv("SUBREDDITS_FG_VERSION", 1)))

    # Join posts with their author and subreddit snapshots taken at the
    # same snapshot_time.
    full_join = posts_fg.select(features=["post_id", "snapshot_time", "num_likes", "upvote_ratio"]).join(
        users_fg.select(features=["user_id", "snapshot_time"]), on=["user_id", "snapshot_time"]).join(
        subreddits_fg.select(features=["subreddit_id", "snapshot_time"]), on=["subreddit_id", "snapshot_time"])
    df = full_join.read()

    # Load the model including the generated images and evaluation scores.
    # The model version is env-configurable for consistency with the
    # feature-group versions above (previously hard-coded to 16).
    mr = project.get_model_registry()
    model_hsfs = mr.get_model("reddit_predict", version=int(os.getenv("MODEL_VERSION", 16)))
    model_dir = model_hsfs.download()
    print("Model directory: {}".format(model_dir))

    # Metric names are stored as "<metric>_<target>"; derive the list of
    # base metric names from the "_likes" entries, then collect the value
    # of each metric for both targets.
    metric_rows = {}
    metrics_avail = [m.replace("_likes", "") for m in model_hsfs.training_metrics if "_likes" in m]
    for target in ["likes", "upvote_ratio"]:
        metric_rows[target] = []
        for metric in metrics_avail:
            metric_rows[target].append(model_hsfs.training_metrics[f"{metric}_{target}"])
    df_metrics = pd.DataFrame(metric_rows, index=metrics_avail)

    # Evaluation plots exported alongside the model during training.
    img_predictions = plt.imread(f"{model_dir}/prediction_error.png")
    img_predictions_logscale = plt.imread(f"{model_dir}/prediction_error_logscale.png")

    return df, img_predictions, img_predictions_logscale, df_metrics
|
49 |
+
|
50 |
+
|
51 |
+
df, img_predictions, img_predictions_logscale, df_metrics = load_data()

st.title("Like It or Not")
# NOTE: fixed grammar in the intro text ("predict" -> "predicts").
st.markdown("This is the dashboard for the Like It Or Not model that predicts the number of likes and the upvote ratio that a Reddit post is going to get.")

# Data stats: unique entity counts from the joined feature data.
st.markdown("## Data Statistics")
col1, col2, col3 = st.columns(3)
col1.metric("Unique Posts", str(df["post_id"].nunique()))  # fixed typo: was "Unqiue"
col2.metric("Unique Users", str(df["user_id"].nunique()))
col3.metric("Unique Subreddits", str(df["subreddit_id"].nunique()))
|
63 |
+
|
64 |
+
# Distribution of the target variables, side by side.
col1, col2 = st.columns(2)
col1.markdown("### Distribution of Number of Likes")
col2.markdown("### Distribution of Upvote Ratio")
col1, col2 = st.columns(2)

# Likes histogram on a log x-axis (like counts are heavily right-skewed).
fig, ax = plt.subplots()
sns.histplot(df["num_likes"], ax=ax)
ax.set_ylabel("Number of posts")
ax.set_xlabel("Number of likes (log scale)")
ax.set_xscale("log")
plt.tight_layout()
col1.pyplot(fig)

# Upvote-ratio distribution. Previously used sns.distplot, which is
# deprecated (removed in seaborn 0.14) and draws a density curve that
# contradicted the "Number of posts" y-label; histplot matches both the
# label and the likes plot above.
fig2, ax = plt.subplots()
sns.histplot(df["upvote_ratio"], ax=ax)
ax.set_ylabel("Number of posts")
plt.tight_layout()
col2.pyplot(fig2)
|
82 |
+
|
83 |
+
# Performance metrics: test-set scores per target, as loaded from the
# model registry by load_data().
st.markdown("## Performance Metrics")
st.markdown("The model achieved the below scores on the test set. Please keep the effect of the sample weights in mind as explained in the Github repository. These reduce for example the R2 score from 0.75 to roughly 0. However, despite these low scores, the model is more useful in practice as it provides a meaningful lower bound estimate of the likes to be received as opposed to overestimating every post by up to 1500")
st.dataframe(df_metrics)

# Prediction error plots exported during training (see load_data()).
st.markdown("## Prediction Error Plots")
st.markdown("The green line indicates the perfect prediction while the blue lines show point densities. Every point represents a prediction. The model is optimized for the number of likes and provides an estimate for the minimum number of likes expected. The upvote ratio does not perform well and would profit from dedicated modeling with another objective function if it is important.")
st.markdown("### Linear Scale")
st.image(img_predictions)
st.markdown("### Log Scale")
st.image(img_predictions_logscale)

# Removed stale commented-out stubs for a confusion matrix and an
# evaluation-scores table: they referenced an image file and DataFrame
# columns ("metric1".."metric4") that do not exist in this app.
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bayesian_optimization==1.4.2
|
2 |
+
great_expectations==0.14.12
|
3 |
+
hopsworks==3.0.5
|
4 |
+
hsml==3.0.3
|
5 |
+
joblib==1.2.0
|
6 |
+
matplotlib==3.5.3
|
7 |
+
modal==0.4.1
|
8 |
+
modal_client==0.44.679
|
9 |
+
numpy==1.21.6
|
10 |
+
pandas==1.3.5
|
11 |
+
praw==7.6.1
|
12 |
+
pytest==7.2.0
|
13 |
+
python-dotenv==0.21.0
|
14 |
+
scikit_learn==1.2.0
|
15 |
+
scipy==1.7.3
|
16 |
+
seaborn==0.12.2
|
17 |
+
sentence_transformers==2.2.2
|
18 |
+
shap==0.41.0
|
19 |
+
streamlit==1.16.0
|
20 |
+
tqdm==4.64.1
|
21 |
+
transformers==4.25.1
|
22 |
+
xgboost==1.6.2
|