alexander-lazarin
commited on
Commit
•
2955021
1
Parent(s):
616a337
add the amazon chronos model
Browse files- app.py +92 -66
- requirements.txt +1 -0
app.py
CHANGED
@@ -6,30 +6,24 @@ import re
|
|
6 |
import logging
|
7 |
import psycopg2
|
8 |
import os
|
|
|
|
|
|
|
9 |
|
10 |
try:
|
11 |
-
|
12 |
-
|
13 |
except:
|
14 |
-
|
15 |
|
16 |
logging.getLogger("prophet").setLevel(logging.WARNING)
|
17 |
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
|
18 |
|
19 |
# Dictionary to map Russian month names to month numbers
|
20 |
russian_months = {
|
21 |
-
"январь": "01",
|
22 |
-
"
|
23 |
-
"
|
24 |
-
"апрель": "04",
|
25 |
-
"май": "05",
|
26 |
-
"июнь": "06",
|
27 |
-
"июль": "07",
|
28 |
-
"август": "08",
|
29 |
-
"сентябрь": "09",
|
30 |
-
"октябрь": "10",
|
31 |
-
"ноябрь": "11",
|
32 |
-
"декабрь": "12"
|
33 |
}
|
34 |
|
35 |
def read_and_process_file(file):
|
@@ -83,8 +77,7 @@ def get_data_from_db(query):
|
|
83 |
conn.close()
|
84 |
return data
|
85 |
|
86 |
-
def forecast_time_series(file, product_name, wb, ozon):
|
87 |
-
|
88 |
if file is None:
|
89 |
# Construct the query
|
90 |
marketplaces = []
|
@@ -113,83 +106,109 @@ def forecast_time_series(file, product_name, wb, ozon):
|
|
113 |
|
114 |
data.iloc[:, 0] = pd.to_datetime(data.iloc[:, 0], format='%Y-%m-%d')
|
115 |
data.set_index('ds', inplace=True)
|
116 |
-
|
117 |
else:
|
118 |
data, period_type, period_col = read_and_process_file(file)
|
119 |
|
120 |
-
|
121 |
-
|
122 |
if period_type == "Month":
|
123 |
year = 12
|
124 |
n_periods = 24
|
125 |
freq = "MS"
|
126 |
else:
|
127 |
year = 52
|
128 |
-
n_periods = year * 2
|
129 |
freq = "W"
|
130 |
|
131 |
-
# Prepare data for Prophet
|
132 |
df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
model = Prophet()
|
138 |
model.fit(df)
|
139 |
-
|
140 |
-
# Create future dataframe
|
141 |
future = model.make_future_dataframe(periods=n_periods, freq=freq)
|
142 |
-
|
143 |
-
# Forecasting
|
144 |
forecast = model.predict(future)
|
145 |
-
|
146 |
-
# Calculate the YoY change
|
147 |
sum_last_year_original = df['y'].iloc[-year:].sum()
|
148 |
sum_first_year_forecast = forecast['yhat'].iloc[-n_periods:-n_periods + year].sum()
|
149 |
yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
|
152 |
fig = go.Figure()
|
153 |
fig.add_trace(go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'))
|
154 |
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')))
|
155 |
-
fig.add_trace(go.Scatter(
|
156 |
-
|
157 |
-
|
158 |
-
fill=None,
|
159 |
-
mode='lines',
|
160 |
-
line=dict(color='pink'),
|
161 |
-
# showlegend=False,
|
162 |
-
name='Lower CI'
|
163 |
-
))
|
164 |
-
fig.add_trace(go.Scatter(
|
165 |
-
x=forecast['ds'],
|
166 |
-
y=forecast['yhat_upper'],
|
167 |
-
fill='tonexty',
|
168 |
-
mode='lines',
|
169 |
-
line=dict(color='pink'),
|
170 |
-
name='Upper CI'
|
171 |
-
))
|
172 |
fig.update_layout(
|
173 |
title='Observed Time Series and Forecast with Confidence Intervals',
|
174 |
xaxis_title='Date',
|
175 |
yaxis_title='Values',
|
176 |
-
legend=dict(
|
177 |
-
orientation='h',
|
178 |
-
yanchor='bottom',
|
179 |
-
y=1.02,
|
180 |
-
xanchor='right',
|
181 |
-
x=1
|
182 |
-
),
|
183 |
hovermode='x unified'
|
184 |
)
|
185 |
-
|
186 |
-
|
187 |
-
combined_df = pd.concat([data, forecast.set_index('ds')[['yhat', 'yhat_lower', 'yhat_upper']]], axis=1)
|
188 |
-
combined_file = 'combined_data.csv'
|
189 |
-
combined_df.to_csv(combined_file)
|
190 |
-
|
191 |
-
# Return plot, YoY change, and file path for export
|
192 |
-
return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
|
193 |
|
194 |
# Create Gradio interface using Blocks
|
195 |
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
|
@@ -206,6 +225,9 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
|
|
206 |
with gr.Row():
|
207 |
product_name_input = gr.Textbox(label="Product Name Filter", value="product_name like '%пуховик%'")
|
208 |
|
|
|
|
|
|
|
209 |
with gr.Row():
|
210 |
compute_button = gr.Button("Compute")
|
211 |
|
@@ -218,7 +240,11 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
|
|
218 |
with gr.Row():
|
219 |
csv_output = gr.File(label="Download Combined Data CSV")
|
220 |
|
221 |
-
compute_button.click(
|
|
|
|
|
|
|
|
|
222 |
|
223 |
# Launch the interface
|
224 |
-
interface.launch(debug=True)
|
|
|
6 |
import logging
|
7 |
import psycopg2
|
8 |
import os
|
9 |
+
import torch
|
10 |
+
from chronos import ChronosPipeline
|
11 |
+
import numpy as np
|
12 |
|
13 |
try:
|
14 |
+
from google.colab import userdata
|
15 |
+
PG_PASSWORD = userdata.get('FASHION_PG_PASS')
|
16 |
except:
|
17 |
+
PG_PASSWORD = os.environ['FASHION_PG_PASS']
|
18 |
|
19 |
logging.getLogger("prophet").setLevel(logging.WARNING)
|
20 |
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
|
21 |
|
22 |
# Dictionary to map Russian month names to month numbers
|
23 |
russian_months = {
|
24 |
+
"январь": "01", "февраль": "02", "март": "03", "апрель": "04",
|
25 |
+
"май": "05", "июнь": "06", "июль": "07", "август": "08",
|
26 |
+
"сентябрь": "09", "октябрь": "10", "ноябрь": "11", "декабрь": "12"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
}
|
28 |
|
29 |
def read_and_process_file(file):
|
|
|
77 |
conn.close()
|
78 |
return data
|
79 |
|
80 |
+
def forecast_time_series(file, product_name, wb, ozon, model_choice):
|
|
|
81 |
if file is None:
|
82 |
# Construct the query
|
83 |
marketplaces = []
|
|
|
106 |
|
107 |
data.iloc[:, 0] = pd.to_datetime(data.iloc[:, 0], format='%Y-%m-%d')
|
108 |
data.set_index('ds', inplace=True)
|
|
|
109 |
else:
|
110 |
data, period_type, period_col = read_and_process_file(file)
|
111 |
|
|
|
|
|
112 |
if period_type == "Month":
|
113 |
year = 12
|
114 |
n_periods = 24
|
115 |
freq = "MS"
|
116 |
else:
|
117 |
year = 52
|
118 |
+
n_periods = year * 2
|
119 |
freq = "W"
|
120 |
|
|
|
121 |
df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
|
122 |
|
123 |
+
if model_choice == "Prophet":
|
124 |
+
forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
|
125 |
+
elif model_choice == "Chronos":
|
126 |
+
forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
|
127 |
+
else:
|
128 |
+
raise ValueError("Invalid model choice")
|
129 |
+
|
130 |
+
# Create Plotly figure (common for both models)
|
131 |
+
fig = create_plot(data, forecast)
|
132 |
+
|
133 |
+
# Combine original data and forecast
|
134 |
+
combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)
|
135 |
+
|
136 |
+
# Save combined data
|
137 |
+
combined_file = 'combined_data.csv'
|
138 |
+
combined_df.to_csv(combined_file)
|
139 |
+
|
140 |
+
return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
|
141 |
+
|
142 |
+
def forecast_prophet(df, n_periods, freq, year):
|
143 |
model = Prophet()
|
144 |
model.fit(df)
|
|
|
|
|
145 |
future = model.make_future_dataframe(periods=n_periods, freq=freq)
|
|
|
|
|
146 |
forecast = model.predict(future)
|
147 |
+
|
|
|
148 |
sum_last_year_original = df['y'].iloc[-year:].sum()
|
149 |
sum_first_year_forecast = forecast['yhat'].iloc[-n_periods:-n_periods + year].sum()
|
150 |
yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
|
151 |
+
|
152 |
+
return forecast, yoy_change
|
153 |
+
|
154 |
+
def forecast_chronos(df, n_periods, freq, year):
|
155 |
+
pipeline = ChronosPipeline.from_pretrained(
|
156 |
+
"amazon/chronos-t5-mini",
|
157 |
+
device_map="cpu",
|
158 |
+
torch_dtype=torch.bfloat16,
|
159 |
+
)
|
160 |
+
|
161 |
+
# Check for non-numeric values
|
162 |
+
if not pd.api.types.is_numeric_dtype(df['y']):
|
163 |
+
non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
|
164 |
+
if not non_numeric.empty:
|
165 |
+
error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
|
166 |
+
raise ValueError(error_message)
|
167 |
+
|
168 |
+
try:
|
169 |
+
y_values = df['y'].values.astype(np.float32)
|
170 |
+
except ValueError as e:
|
171 |
+
raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}")
|
172 |
+
|
173 |
+
chronos_forecast = pipeline.predict(
|
174 |
+
context=torch.tensor(y_values),
|
175 |
+
prediction_length=n_periods,
|
176 |
+
num_samples=20,
|
177 |
+
limit_prediction_length=False
|
178 |
+
)
|
179 |
+
|
180 |
+
forecast_index = pd.date_range(start=df['ds'].iloc[-1], periods=n_periods+1, freq=freq)[1:]
|
181 |
+
low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
|
182 |
+
|
183 |
+
forecast = pd.DataFrame({
|
184 |
+
'ds': forecast_index,
|
185 |
+
'yhat': median,
|
186 |
+
'yhat_lower': low,
|
187 |
+
'yhat_upper': high
|
188 |
+
})
|
189 |
+
|
190 |
+
sum_last_year_original = df['y'].iloc[-year:].sum()
|
191 |
+
sum_first_year_forecast = median[:year].sum()
|
192 |
+
yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
|
193 |
+
|
194 |
+
return forecast, yoy_change
|
195 |
|
196 |
+
def create_plot(data, forecast):
|
197 |
fig = go.Figure()
|
198 |
fig.add_trace(go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'))
|
199 |
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')))
|
200 |
+
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_lower'], fill=None, mode='lines', line=dict(color='pink'), name='Lower CI'))
|
201 |
+
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=dict(color='pink'), name='Upper CI'))
|
202 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
fig.update_layout(
|
204 |
title='Observed Time Series and Forecast with Confidence Intervals',
|
205 |
xaxis_title='Date',
|
206 |
yaxis_title='Values',
|
207 |
+
legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
hovermode='x unified'
|
209 |
)
|
210 |
+
|
211 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
# Create Gradio interface using Blocks
|
214 |
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
|
|
|
225 |
with gr.Row():
|
226 |
product_name_input = gr.Textbox(label="Product Name Filter", value="product_name like '%пуховик%'")
|
227 |
|
228 |
+
with gr.Row():
|
229 |
+
model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")
|
230 |
+
|
231 |
with gr.Row():
|
232 |
compute_button = gr.Button("Compute")
|
233 |
|
|
|
240 |
with gr.Row():
|
241 |
csv_output = gr.File(label="Download Combined Data CSV")
|
242 |
|
243 |
+
compute_button.click(
|
244 |
+
forecast_time_series,
|
245 |
+
inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
|
246 |
+
outputs=[plot_output, yoy_output, csv_output]
|
247 |
+
)
|
248 |
|
249 |
# Launch the interface
|
250 |
+
interface.launch(debug=True)
|
requirements.txt
CHANGED
@@ -4,3 +4,4 @@ pandas
|
|
4 |
plotly
|
5 |
prophet
|
6 |
psycopg2
|
|
|
|
4 |
plotly
|
5 |
prophet
|
6 |
psycopg2
|
7 |
+
git+https://github.com/amazon-science/chronos-forecasting.git
|