alexander-lazarin committed on
Commit
2955021
1 Parent(s): 616a337

add the amazon chronos model

Browse files
Files changed (2) hide show
  1. app.py +92 -66
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,30 +6,24 @@ import re
6
  import logging
7
  import psycopg2
8
  import os
 
 
 
9
 
10
  try:
11
- from google.colab import userdata
12
- PG_PASSWORD = userdata.get('FASHION_PG_PASS')
13
  except:
14
- PG_PASSWORD = os.environ['FASHION_PG_PASS']
15
 
16
  logging.getLogger("prophet").setLevel(logging.WARNING)
17
  logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
18
 
19
  # Dictionary to map Russian month names to month numbers
20
  russian_months = {
21
- "январь": "01",
22
- "февраль": "02",
23
- "март": "03",
24
- "апрель": "04",
25
- "май": "05",
26
- "июнь": "06",
27
- "июль": "07",
28
- "август": "08",
29
- "сентябрь": "09",
30
- "октябрь": "10",
31
- "ноябрь": "11",
32
- "декабрь": "12"
33
  }
34
 
35
  def read_and_process_file(file):
@@ -83,8 +77,7 @@ def get_data_from_db(query):
83
  conn.close()
84
  return data
85
 
86
- def forecast_time_series(file, product_name, wb, ozon):
87
-
88
  if file is None:
89
  # Construct the query
90
  marketplaces = []
@@ -113,83 +106,109 @@ def forecast_time_series(file, product_name, wb, ozon):
113
 
114
  data.iloc[:, 0] = pd.to_datetime(data.iloc[:, 0], format='%Y-%m-%d')
115
  data.set_index('ds', inplace=True)
116
-
117
  else:
118
  data, period_type, period_col = read_and_process_file(file)
119
 
120
-
121
-
122
  if period_type == "Month":
123
  year = 12
124
  n_periods = 24
125
  freq = "MS"
126
  else:
127
  year = 52
128
- n_periods = year * 2 # Number of periods to forecast
129
  freq = "W"
130
 
131
- # Prepare data for Prophet
132
  df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
133
 
134
- # Fit the Prophet model
135
- print(df.dtypes)
136
- print(df.head(3))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  model = Prophet()
138
  model.fit(df)
139
-
140
- # Create future dataframe
141
  future = model.make_future_dataframe(periods=n_periods, freq=freq)
142
-
143
- # Forecasting
144
  forecast = model.predict(future)
145
-
146
- # Calculate the YoY change
147
  sum_last_year_original = df['y'].iloc[-year:].sum()
148
  sum_first_year_forecast = forecast['yhat'].iloc[-n_periods:-n_periods + year].sum()
149
  yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # Create an interactive plot with Plotly
152
  fig = go.Figure()
153
  fig.add_trace(go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'))
154
  fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')))
155
- fig.add_trace(go.Scatter(
156
- x=forecast['ds'],
157
- y=forecast['yhat_lower'],
158
- fill=None,
159
- mode='lines',
160
- line=dict(color='pink'),
161
- # showlegend=False,
162
- name='Lower CI'
163
- ))
164
- fig.add_trace(go.Scatter(
165
- x=forecast['ds'],
166
- y=forecast['yhat_upper'],
167
- fill='tonexty',
168
- mode='lines',
169
- line=dict(color='pink'),
170
- name='Upper CI'
171
- ))
172
  fig.update_layout(
173
  title='Observed Time Series and Forecast with Confidence Intervals',
174
  xaxis_title='Date',
175
  yaxis_title='Values',
176
- legend=dict(
177
- orientation='h',
178
- yanchor='bottom',
179
- y=1.02,
180
- xanchor='right',
181
- x=1
182
- ),
183
  hovermode='x unified'
184
  )
185
-
186
- # Combine original data and forecast data into one DataFrame for export
187
- combined_df = pd.concat([data, forecast.set_index('ds')[['yhat', 'yhat_lower', 'yhat_upper']]], axis=1)
188
- combined_file = 'combined_data.csv'
189
- combined_df.to_csv(combined_file)
190
-
191
- # Return plot, YoY change, and file path for export
192
- return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
193
 
194
  # Create Gradio interface using Blocks
195
  with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
@@ -206,6 +225,9 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
206
  with gr.Row():
207
  product_name_input = gr.Textbox(label="Product Name Filter", value="product_name like '%пуховик%'")
208
 
 
 
 
209
  with gr.Row():
210
  compute_button = gr.Button("Compute")
211
 
@@ -218,7 +240,11 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
218
  with gr.Row():
219
  csv_output = gr.File(label="Download Combined Data CSV")
220
 
221
- compute_button.click(forecast_time_series, inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox], outputs=[plot_output, yoy_output, csv_output])
 
 
 
 
222
 
223
  # Launch the interface
224
- interface.launch(debug=True)
 
6
  import logging
7
  import psycopg2
8
  import os
9
+ import torch
10
+ from chronos import ChronosPipeline
11
+ import numpy as np
12
 
13
  try:
14
+ from google.colab import userdata
15
+ PG_PASSWORD = userdata.get('FASHION_PG_PASS')
16
  except:
17
+ PG_PASSWORD = os.environ['FASHION_PG_PASS']
18
 
19
  logging.getLogger("prophet").setLevel(logging.WARNING)
20
  logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
21
 
22
  # Dictionary to map Russian month names to month numbers
23
  russian_months = {
24
+ "январь": "01", "февраль": "02", "март": "03", "апрель": "04",
25
+ "май": "05", "июнь": "06", "июль": "07", "август": "08",
26
+ "сентябрь": "09", "октябрь": "10", "ноябрь": "11", "декабрь": "12"
 
 
 
 
 
 
 
 
 
27
  }
28
 
29
  def read_and_process_file(file):
 
77
  conn.close()
78
  return data
79
 
80
+ def forecast_time_series(file, product_name, wb, ozon, model_choice):
 
81
  if file is None:
82
  # Construct the query
83
  marketplaces = []
 
106
 
107
  data.iloc[:, 0] = pd.to_datetime(data.iloc[:, 0], format='%Y-%m-%d')
108
  data.set_index('ds', inplace=True)
 
109
  else:
110
  data, period_type, period_col = read_and_process_file(file)
111
 
 
 
112
  if period_type == "Month":
113
  year = 12
114
  n_periods = 24
115
  freq = "MS"
116
  else:
117
  year = 52
118
+ n_periods = year * 2
119
  freq = "W"
120
 
 
121
  df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
122
 
123
+ if model_choice == "Prophet":
124
+ forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
125
+ elif model_choice == "Chronos":
126
+ forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
127
+ else:
128
+ raise ValueError("Invalid model choice")
129
+
130
+ # Create Plotly figure (common for both models)
131
+ fig = create_plot(data, forecast)
132
+
133
+ # Combine original data and forecast
134
+ combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)
135
+
136
+ # Save combined data
137
+ combined_file = 'combined_data.csv'
138
+ combined_df.to_csv(combined_file)
139
+
140
+ return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
141
+
142
def forecast_prophet(df, n_periods, freq, year):
    """Fit a Prophet model on ``df`` and forecast ``n_periods`` steps ahead.

    Args:
        df: DataFrame with a ``ds`` column (period dates) and a ``y`` column
            (observed values), as passed by ``forecast_time_series``.
        n_periods: Number of future periods to forecast.
        freq: Pandas frequency string handed to ``make_future_dataframe``.
        year: Number of periods that make up one year of data.

    Returns:
        A ``(forecast, yoy_change)`` tuple. ``forecast`` is the DataFrame
        returned by ``model.predict``; ``yoy_change`` is the relative change
        between the sum of the first ``year`` forecast periods and the sum of
        the last ``year`` observed values.
    """
    model = Prophet()
    model.fit(df)

    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    sum_last_year_original = df['y'].iloc[-year:].sum()

    # The last `n_periods` rows of the prediction are the future periods.
    # Select that tail first and then its first `year` rows: a single slice
    # `iloc[-n_periods:-n_periods + year]` has a stop index of 0 (or a
    # positive one) whenever n_periods <= year, which silently selects the
    # wrong window.
    future_yhat = forecast['yhat'].iloc[-n_periods:]
    sum_first_year_forecast = future_yhat.iloc[:year].sum()

    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original

    return forecast, yoy_change
153
+
154
def forecast_chronos(df, n_periods, freq, year):
    """Forecast ``n_periods`` steps ahead with the Chronos T5-mini model.

    Args:
        df: DataFrame with a ``ds`` column (period dates) and a ``y`` column
            (observed values).
        n_periods: Number of future periods to forecast.
        freq: Pandas frequency string used to build the forecast dates.
        year: Number of periods that make up one year of data.

    Returns:
        A ``(forecast, yoy_change)`` tuple. ``forecast`` is a DataFrame with
        columns ``ds``, ``yhat`` (0.5 quantile), ``yhat_lower`` (0.1 quantile)
        and ``yhat_upper`` (0.9 quantile); ``yoy_change`` is the relative
        change between the sum of the first ``year`` median forecasts and the
        sum of the last ``year`` observed values.

    Raises:
        ValueError: If ``y`` contains values that cannot be interpreted as
            numbers or cannot be converted to float32.
    """
    # Validate the input BEFORE loading the model so that bad data fails fast
    # instead of paying for a model load first.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)

    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )

    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # Build the future dates the same way Prophet's future frame does:
    # generate one extra period and keep only dates strictly after the last
    # observation. Dropping the first element unconditionally skips a period
    # whenever the last observed date is not on the frequency anchor, because
    # date_range then already starts at the next anchored date.
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidate_dates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidate_dates[candidate_dates > last_date][:n_periods]

    # Quantiles are taken over axis 0 of the first (only) series' output,
    # presumably the sample axis - confirm against the Chronos docs.
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })

    # Sum the numeric view of the history: an object-dtype column holding
    # numeric strings passes the validation above, and summing it directly
    # would concatenate the strings.
    sum_last_year_original = pd.to_numeric(df['y']).iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original

    return forecast, yoy_change
195
 
196
def create_plot(data, forecast):
    """Build the Plotly figure shared by both forecasting models.

    Args:
        data: Observed series; its index is used for the x axis and its first
            column for the y axis.
        forecast: DataFrame with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns.

    Returns:
        A ``go.Figure`` with four line traces (observed, forecast, lower and
        upper bound) and the layout configured.
    """
    forecast_dates = forecast['ds']
    band_line = dict(color='pink')

    # Order matters: the upper bound uses fill='tonexty', so it must come
    # directly after the lower bound trace.
    traces = [
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=forecast_dates, y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')),
        go.Scatter(x=forecast_dates, y=forecast['yhat_lower'], fill=None, mode='lines', line=band_line, name='Lower CI'),
        go.Scatter(x=forecast_dates, y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=band_line, name='Upper CI'),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    # Horizontal legend anchored at the top-right, just above the plot area.
    legend_layout = dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1)
    fig.update_layout(
        title='Observed Time Series and Forecast with Confidence Intervals',
        xaxis_title='Date',
        yaxis_title='Values',
        legend=legend_layout,
        hovermode='x unified',
    )

    return fig
 
 
 
 
 
 
212
 
213
  # Create Gradio interface using Blocks
214
  with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
 
225
  with gr.Row():
226
  product_name_input = gr.Textbox(label="Product Name Filter", value="product_name like '%пуховик%'")
227
 
228
+ with gr.Row():
229
+ model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")
230
+
231
  with gr.Row():
232
  compute_button = gr.Button("Compute")
233
 
 
240
  with gr.Row():
241
  csv_output = gr.File(label="Download Combined Data CSV")
242
 
243
+ compute_button.click(
244
+ forecast_time_series,
245
+ inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
246
+ outputs=[plot_output, yoy_output, csv_output]
247
+ )
248
 
249
  # Launch the interface
250
+ interface.launch(debug=True)
requirements.txt CHANGED
@@ -4,3 +4,4 @@ pandas
4
  plotly
5
  prophet
6
  psycopg2
 
 
4
  plotly
5
  prophet
6
  psycopg2
7
+ git+https://github.com/amazon-science/chronos-forecasting.git