xinah3131 commited on
Commit
1c6bb71
1 Parent(s): b68bb4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -9,8 +9,10 @@ import matplotlib.pyplot as plt
9
  import numpy as np
10
  import seaborn as sns
11
  # Load the model
12
-
13
- model = joblib.load('85pct(new).pkl')
 
 
14
 
15
  # Define the categories
16
  categories = {
@@ -114,6 +116,8 @@ def main():
114
 
115
  # Input fields
116
  if selected_option == "Predict":
 
 
117
  with st.container():
118
  col1, col2, col3 = st.columns(3)
119
  getTitle, getDuration, getCategory = "", 0.00, 1
@@ -161,7 +165,7 @@ def main():
161
  st.warning("Please enter a duration.")
162
 
163
  else:
164
- prediction = predict_trend(title, duration, categoryId)
165
  if prediction[0] == 1:
166
  st.success("This video is predicted to be a trend!")
167
  st.markdown("![Alt Text](https://media.tenor.com/Cyi2zT7wcmcAAAAj/pentol-gif-eak.gif)")
@@ -182,7 +186,7 @@ def main():
182
  st.download_button(
183
  "Download",
184
  csv,
185
- "file.csv",
186
  "text/csv",
187
  key='download-csv'
188
  )
@@ -241,7 +245,7 @@ def main():
241
  show_top_titleLength()
242
  with col13:
243
  show_best_titleLength()
244
-
245
  def get_picture_from_url(url):
246
  try:
247
  response = requests.get(url)
@@ -388,7 +392,7 @@ def show_best_titleLength():
388
  st.write(f"<span style='color:{color};font-weight:bold;'>{range}</span>", unsafe_allow_html=True)
389
 
390
  # Function to make predictions
391
- def predict_trend(title, duration, category_id):
392
  duration = str(duration)
393
  category_id = int(category_id)
394
  clean_new_title = preprocess(title)
 
9
  import numpy as np
10
  import seaborn as sns
11
  # Load the model
12
+ def read_model(region):
13
+ if(region == "United States"):
14
+ model = joblib.load('85pct(2).pkl')
15
+ return model
16
 
17
  # Define the categories
18
  categories = {
 
116
 
117
  # Input fields
118
  if selected_option == "Predict":
119
+ region = st.sidebar.selectbox("Select Region", ['United States'])
120
+ model = read_model(region)
121
  with st.container():
122
  col1, col2, col3 = st.columns(3)
123
  getTitle, getDuration, getCategory = "", 0.00, 1
 
165
  st.warning("Please enter a duration.")
166
 
167
  else:
168
+ prediction = predict_trend(model,title, duration, categoryId)
169
  if prediction[0] == 1:
170
  st.success("This video is predicted to be a trend!")
171
  st.markdown("![Alt Text](https://media.tenor.com/Cyi2zT7wcmcAAAAj/pentol-gif-eak.gif)")
 
186
  st.download_button(
187
  "Download",
188
  csv,
189
+ "top10Trending.csv",
190
  "text/csv",
191
  key='download-csv'
192
  )
 
245
  show_top_titleLength()
246
  with col13:
247
  show_best_titleLength()
248
+
249
  def get_picture_from_url(url):
250
  try:
251
  response = requests.get(url)
 
392
  st.write(f"<span style='color:{color};font-weight:bold;'>{range}</span>", unsafe_allow_html=True)
393
 
394
  # Function to make predictions
395
+ def predict_trend(model,title, duration, category_id):
396
  duration = str(duration)
397
  category_id = int(category_id)
398
  clean_new_title = preprocess(title)