# Source metadata: author ryanrahmadifa, commit 79e1719 ("Added files"), 3.68 kB.
import pickle
from neuralforecast import NeuralForecast
from modules.transform import calendarFeatures
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
def _recent_history(Y_df, horizon_len):
    """Return a copy of the rows of ``Y_df`` whose ``ds`` is within ``2 * horizon_len`` days of its latest ``ds``."""
    # NOTE: the window is measured in calendar days even though the model
    # frequency is forced to business days in main().
    cutoff = Y_df['ds'].max() - pd.Timedelta(days=horizon_len * 2)
    return Y_df[Y_df['ds'] >= cutoff].copy()


def _forecast_columns(Y_hat):
    """Return the names of the prediction columns in ``Y_hat`` (everything except id/time bookkeeping)."""
    # 'index' is excluded because reset_index() adds such a column when
    # predict() already returns unique_id as a regular column (behaviour of
    # newer neuralforecast releases, presumably) - TODO confirm installed version.
    bookkeeping = {'ds', 'unique_id', 'index'}
    return [col for col in Y_hat.columns if col not in bookkeeping]


def _draw_forecast(ax, actual, forecast, forecast_cols, title):
    """Plot the actual ``y`` history plus every forecast column onto ``ax`` and apply the shared styling."""
    sns.lineplot(data=actual, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
    for col in forecast_cols:
        sns.lineplot(data=forecast, x='ds', y=col, label=col, ax=ax, linewidth=2)
    ax.set_title(title, fontsize=22)
    ax.set_ylabel('Y', fontsize=20)
    ax.set_xlabel('Date', fontsize=20)
    ax.legend(prop={'size': 15})
    ax.grid(True)


def _save_figure(fig, artifacts_path):
    """Write ``fig`` to ``<artifacts_path>/forecast_plot.jpg`` at 300 dpi and release it."""
    plot_path = os.path.join(artifacts_path, 'forecast_plot.jpg')
    fig.savefig(plot_path, format='jpg', dpi=300)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def main(artifacts_path, variate='uni', horizon_len=5):
    """Forecast from a saved NeuralForecast run, plot the result and save predictions.

    Reads ``transformations.pkl``, the ``model`` directory and ``train_data.csv``
    from ``artifacts_path``. Writes ``forecast_plot.jpg`` and
    ``prediction_results.csv`` back into the same directory.

    Args:
        artifacts_path: Directory holding the run's saved artifacts.
        variate: ``'univariate'`` (alias ``'uni'``, the default) adds calendar
            features to the future frame and draws all series on one axis.
            ``'multivariate'`` (alias ``'multi'``) draws one subplot per
            ``unique_id``.
        horizon_len: Controls the plotted history window (``2 * horizon_len`` days).

    Raises:
        ValueError: If ``variate`` is not one of the accepted values.
    """
    # Validate before touching any artifacts. The original default 'uni'
    # matched no branch and crashed on the final to_csv call.
    if variate in ('uni', 'univariate'):
        multivariate = False
    elif variate in ('multi', 'multivariate'):
        multivariate = True
    else:
        raise ValueError(
            f"variate must be 'univariate'/'uni' or 'multivariate'/'multi', got {variate!r}"
        )

    # NOTE(review): the loaded transformations are not used below, and
    # unpickling executes arbitrary code - only load artifacts you trust.
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'rb') as fp:
        transformations_loaded = pickle.load(fp)  # noqa: F841

    nf2 = NeuralForecast.load(path=os.path.join(artifacts_path, 'model'))
    # Override the saved frequency with pandas' business-day alias.
    nf2.freq = 'B'

    Y_df = pd.read_csv(os.path.join(artifacts_path, 'train_data.csv'))
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    futr_df = nf2.make_future_dataframe()
    if not multivariate:
        futr_df = calendarFeatures(futr_df)
    Y_hat = nf2.predict(futr_df=futr_df).reset_index()

    plot_data = _recent_history(Y_df, horizon_len)
    forecast_cols = _forecast_columns(Y_hat)
    sns.set_theme(style="whitegrid")

    if multivariate:
        unique_ids = Y_hat['unique_id'].unique()
        # squeeze=False always yields a 2-D axes array, even for a single series.
        fig, axes = plt.subplots(
            len(unique_ids), 1, figsize=(20, 7 * len(unique_ids)), sharex=True, squeeze=False
        )
        for ax, uid in zip(axes[:, 0], unique_ids):
            _draw_forecast(
                ax,
                plot_data[plot_data['unique_id'] == uid],
                Y_hat[Y_hat['unique_id'] == uid],
                forecast_cols,
                f'Forecasting Result for {uid}',
            )
        fig.tight_layout()
    else:
        fig, ax = plt.subplots(1, 1, figsize=(20, 7))
        _draw_forecast(ax, plot_data, Y_hat, forecast_cols, 'Forecasting Result')

    _save_figure(fig, artifacts_path)
    Y_hat.to_csv(os.path.join(artifacts_path, 'prediction_results.csv'))
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Forecast and plot from the saved artifacts of a training run."
    )
    parser.add_argument(
        '--run-id',
        default='crude_oil_d2eea033-7bf9-437d-809c-d18f574a982d',
        help="Unique id of the run whose artifacts directory is used.",
    )
    parser.add_argument(
        '--variate',
        default='univariate',
        choices=['univariate', 'multivariate'],
        help="Forecasting mode the run was trained for.",
    )
    parser.add_argument(
        '--horizon-len',
        type=int,
        default=5,
        help="Controls the length of the plotted history window.",
    )
    args = parser.parse_args()

    directory = Path(__file__).parent.absolute()
    artifacts_path = os.path.join(directory, 'artifacts', args.run_id)
    log_dir = os.path.join(artifacts_path, 'prediction_logs')
    # NOTE(review): this replaces the attribute on the Trainer *class*, so it
    # affects every Trainer created afterwards (presumably the ones
    # NeuralForecast builds during predict) - confirm this is still needed.
    Trainer.default_root_dir = log_dir
    main(artifacts_path=artifacts_path, variate=args.variate, horizon_len=args.horizon_len)