"""Generate forecasts from a saved NeuralForecast model and plot them against recent history."""

import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from neuralforecast import NeuralForecast
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

from modules.transform import calendarFeatures

# Accepted spellings of ``variate`` mapped to the canonical mode name.
# 'uni' is the default of ``main`` and previously matched no branch at all.
_VARIATE_ALIASES = {
    'uni': 'univariate',
    'univariate': 'univariate',
    'multi': 'multivariate',
    'multivariate': 'multivariate',
}

# Columns of the prediction frame that are identifiers rather than forecasts.
_NON_FORECAST_COLUMNS = ('ds', 'unique_id')


def _recent_history(history, horizon_len):
    """Return the rows of *history* within ``2 * horizon_len`` days of its latest 'ds'."""
    cutoff = history['ds'].max() - pd.Timedelta(days=horizon_len * 2)
    return history[history['ds'] >= cutoff]


def _draw_panel(ax, actual, forecast, forecast_cols, title):
    """Draw the actual series and one line per forecast column on *ax*."""
    sns.lineplot(data=actual, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
    for col in forecast_cols:
        sns.lineplot(data=forecast, x='ds', y=col, label=col, ax=ax, linewidth=2)
    ax.set_title(title, fontsize=22)
    ax.set_ylabel('Y', fontsize=20)
    ax.set_xlabel('Date', fontsize=20)
    ax.legend(prop={'size': 15})
    ax.grid(True)


def _plot_univariate(recent, forecast, forecast_cols):
    """Build a single-panel figure of recent actuals and forecasts; return the figure."""
    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    _draw_panel(ax, recent, forecast, forecast_cols, 'Forecasting Result')
    return fig


def _plot_multivariate(recent, forecast, forecast_cols):
    """Build a figure with one panel per 'unique_id' in *forecast*; return the figure."""
    unique_ids = forecast['unique_id'].unique()
    fig, axes = plt.subplots(
        len(unique_ids), 1, figsize=(20, 7 * len(unique_ids)), sharex=True
    )
    # With a single row plt.subplots returns a bare Axes, not an array.
    if len(unique_ids) == 1:
        axes = [axes]
    for ax, uid in zip(axes, unique_ids):
        _draw_panel(
            ax,
            recent[recent['unique_id'] == uid],
            forecast[forecast['unique_id'] == uid],
            forecast_cols,
            f'Forecasting Result for {uid}',
        )
    fig.tight_layout()
    return fig


def main(artifacts_path, variate='uni', horizon_len=5):
    """Predict with the model stored under *artifacts_path* and save a plot and a CSV.

    Reads ``transformations.pkl``, ``model`` and ``train_data.csv`` from
    *artifacts_path*; writes ``forecast_plot.jpg`` and
    ``prediction_results.csv`` back into the same directory.

    Args:
        artifacts_path: Directory holding the artifacts of one training run.
        variate: ``'univariate'`` (alias ``'uni'``) or ``'multivariate'``
            (alias ``'multi'``). In univariate mode ``calendarFeatures`` is
            applied to the future frame and a single panel is drawn; in
            multivariate mode one panel per ``unique_id`` is drawn.
        horizon_len: The plot shows actuals from the last ``2 * horizon_len``
            calendar days of the training data.

    Raises:
        ValueError: If *variate* is not one of the accepted values.
    """
    # Validate before any expensive loading. Previously an unrecognised value
    # (including the default 'uni') skipped both branches and crashed with a
    # NameError on ``Y_hat`` at the very end.
    try:
        mode = _VARIATE_ALIASES[variate]
    except (KeyError, TypeError):
        raise ValueError(
            f"variate must be one of {sorted(_VARIATE_ALIASES)}, got {variate!r}"
        ) from None

    # SECURITY: pickle.load can execute arbitrary code; only use trusted artifacts.
    # The loaded object is not used below; the load is kept so that a missing
    # or unreadable transformations file still fails here as it did before.
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'rb') as fp:
        pickle.load(fp)

    nf2 = NeuralForecast.load(path=os.path.join(artifacts_path, 'model'))
    nf2.freq = 'B'  # pandas offset alias for business-day frequency

    Y_df = pd.read_csv(os.path.join(artifacts_path, 'train_data.csv'))
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    futr_df = nf2.make_future_dataframe()
    if mode == 'univariate':
        futr_df = calendarFeatures(futr_df)

    # NOTE(review): if predict() already returns 'unique_id' as a column,
    # reset_index() adds an 'index' column that would be plotted as a
    # forecast - confirm against the installed neuralforecast version.
    Y_hat = nf2.predict(futr_df=futr_df).reset_index()
    forecast_cols = [col for col in Y_hat.columns if col not in _NON_FORECAST_COLUMNS]
    recent = _recent_history(Y_df, horizon_len)

    sns.set_theme(style="whitegrid")
    if mode == 'univariate':
        fig = _plot_univariate(recent, Y_hat, forecast_cols)
    else:
        fig = _plot_multivariate(recent, Y_hat, forecast_cols)

    fig.savefig(os.path.join(artifacts_path, 'forecast_plot.jpg'), format='jpg', dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    Y_hat.to_csv(os.path.join(artifacts_path, 'prediction_results.csv'))


if __name__ == "__main__":
    directory = Path(__file__).parent.absolute()
    # TODO: take run_id as an input; it must be the unique id of the run to load.
    run_id = 'crude_oil_d2eea033-7bf9-437d-809c-d18f574a982d'
    artifacts_path = os.path.join(directory, 'artifacts', run_id)
    log_dir = os.path.join(artifacts_path, 'prediction_logs')
    # NOTE(review): this assigns a class-level attribute on Trainer, and the
    # CSVLogger instance is not passed to anything in this file - confirm that
    # both actually redirect the prediction logs as intended.
    Trainer.default_root_dir = log_dir
    logger = CSVLogger(save_dir=log_dir, name='forecast_logs')
    main(artifacts_path=artifacts_path, variate='univariate', horizon_len=5)