# Page header captured along with this file (not program code):
# Spaces: Sleeping; File size: 3,678 Bytes; 79e1719
import pickle
from neuralforecast import NeuralForecast
from modules.transform import calendarFeatures
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
# Accepted spellings of the ``variate`` argument mapped to the canonical mode.
# 'uni' is listed because it is this function's own default value, which
# previously matched neither branch and ended in a NameError on ``Y_hat``.
_VARIATE_ALIASES = {
    'uni': 'univariate',
    'univariate': 'univariate',
    'multi': 'multivariate',
    'multivariate': 'multivariate',
}

# Columns of the prediction frame that identify rows rather than hold model
# forecasts. 'index' is what ``reset_index()`` adds when the frame it is
# applied to already has a default integer index.
_NON_FORECAST_COLUMNS = ('ds', 'unique_id', 'index')


def _recent_history(history, horizon_len):
    """Return the rows of *history* dated within ``2 * horizon_len`` days of its last date."""
    cutoff = history['ds'].max() - pd.Timedelta(days=horizon_len * 2)
    return history[history['ds'] >= cutoff]


def _draw_panel(ax, actual, forecast, title):
    """Plot the actual series and every forecast column on *ax* and style the axis."""
    sns.lineplot(data=actual, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
    for col in forecast.columns:
        if col not in _NON_FORECAST_COLUMNS:
            sns.lineplot(data=forecast, x='ds', y=col, label=col, ax=ax, linewidth=2)
    ax.set_title(title, fontsize=22)
    ax.set_ylabel('Y', fontsize=20)
    ax.set_xlabel('Date', fontsize=20)
    ax.legend(prop={'size': 15})
    ax.grid(True)


def main(artifacts_path, variate='uni', horizon_len=5):
    """Load a saved NeuralForecast model, forecast, and write a plot and a CSV.

    Reads from ``artifacts_path``: ``transformations.pkl``, the ``model``
    directory and ``train_data.csv``. Writes to the same directory:
    ``forecast_plot.jpg`` and ``prediction_results.csv``.

    Args:
        artifacts_path: Directory holding the artifacts of one run.
        variate: ``'univariate'`` (alias ``'uni'``) draws a single panel and
            passes the future frame through ``calendarFeatures`` before
            predicting; ``'multivariate'`` (alias ``'multi'``) draws one panel
            per ``unique_id`` found in the predictions.
        horizon_len: The plot shows actual values from the last
            ``2 * horizon_len`` calendar days of the training data.

    Raises:
        ValueError: If ``variate`` is not one of the accepted values. Raised
            before any artifact is read.
    """
    try:
        mode = _VARIATE_ALIASES[variate]
    except (KeyError, TypeError):
        raise ValueError(
            f"variate must be one of {sorted(_VARIATE_ALIASES)}, got {variate!r}"
        ) from None

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # function at artifact directories produced by a trusted training run.
    # The loaded object is not used further in this function; the load is kept
    # so that a missing or unreadable transformations file still fails here.
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'rb') as fp:
        transformations_loaded = pickle.load(fp)  # noqa: F841

    nf2 = NeuralForecast.load(path=os.path.join(artifacts_path, 'model'))
    nf2.freq = 'B'

    Y_df = pd.read_csv(os.path.join(artifacts_path, 'train_data.csv'))
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    futr_df = nf2.make_future_dataframe()
    if mode == 'univariate':
        futr_df = calendarFeatures(futr_df)
    Y_hat = nf2.predict(futr_df=futr_df).reset_index()

    plot_data = _recent_history(Y_df, horizon_len)
    sns.set_theme(style="whitegrid")

    if mode == 'univariate':
        fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    else:
        unique_ids = Y_hat['unique_id'].unique()
        # squeeze=False always yields a 2-D array of axes, so a single series
        # needs no special-casing.
        fig, axes = plt.subplots(
            len(unique_ids), 1,
            figsize=(20, 7 * len(unique_ids)),
            sharex=True,
            squeeze=False,
        )

    try:
        if mode == 'univariate':
            _draw_panel(ax, plot_data, Y_hat, 'Forecasting Result')
        else:
            for ax, uid in zip(axes.ravel(), unique_ids):
                _draw_panel(
                    ax,
                    plot_data[plot_data['unique_id'] == uid],
                    Y_hat[Y_hat['unique_id'] == uid],
                    f'Forecasting Result for {uid}',
                )
            fig.tight_layout()
        plot_path = os.path.join(artifacts_path, 'forecast_plot.jpg')
        fig.savefig(plot_path, format='jpg', dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    Y_hat.to_csv(os.path.join(artifacts_path, 'prediction_results.csv'))
if __name__ == "__main__":
    import sys

    directory = Path(__file__).parent.absolute()

    # Unique id of the run whose artifacts are used. It may be passed as the
    # first command-line argument; without one, the previously hard-coded run
    # id is used, so invoking the script with no arguments behaves as before.
    default_run_id = 'crude_oil_d2eea033-7bf9-437d-809c-d18f574a982d'
    run_id = sys.argv[1] if len(sys.argv) > 1 else default_run_id

    artifacts_path = os.path.join(directory, 'artifacts', run_id)
    log_dir = os.path.join(artifacts_path, 'prediction_logs')

    # NOTE(review): this assigns on the Trainer class itself, presumably to
    # redirect Lightning output created during prediction into log_dir;
    # confirm it has the intended effect with the installed version.
    Trainer.default_root_dir = log_dir
    # NOTE(review): this logger is created but not passed to anything visible
    # in this file; verify whether it is still needed.
    logger = CSVLogger(save_dir=log_dir, name='forecast_logs')

    main(artifacts_path=artifacts_path, variate='univariate', horizon_len=5)