Spaces:
Sleeping
Sleeping
import pickle | |
from neuralforecast import NeuralForecast | |
from modules.transform import calendarFeatures | |
import os | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from pathlib import Path | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.loggers import CSVLogger | |
# Accepted spellings of the ``variate`` argument, mapped to a canonical mode.
# 'uni' is included because it is the default value of ``main``'s parameter.
_VARIATE_MODES = {
    'uni': 'univariate',
    'univariate': 'univariate',
    'multi': 'multivariate',
    'multivariate': 'multivariate',
}


def _recent_history(Y_df, horizon_len):
    """Return a copy of the rows of ``Y_df`` within ``horizon_len * 2`` days of its last ``ds``."""
    cutoff = Y_df['ds'].max() - pd.Timedelta(days=horizon_len * 2)
    return Y_df[Y_df['ds'] >= cutoff].copy()


def _draw_panel(ax, actual, forecast, title):
    """Plot the actual ``y`` series and every forecast column of ``forecast`` on ``ax``."""
    sns.lineplot(data=actual, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
    for col in forecast.columns:
        # Everything except the key columns is treated as a forecast series.
        if col not in ('ds', 'unique_id'):
            sns.lineplot(data=forecast, x='ds', y=col, label=col, ax=ax, linewidth=2)
    ax.set_title(title, fontsize=22)
    ax.set_ylabel('Y', fontsize=20)
    ax.set_xlabel('Date', fontsize=20)
    ax.legend(prop={'size': 15})
    ax.grid(True)


def _save_and_close(fig, artifacts_path):
    """Write ``fig`` to ``forecast_plot.jpg`` in ``artifacts_path`` and release it."""
    plot_path = os.path.join(artifacts_path, 'forecast_plot.jpg')
    try:
        fig.savefig(plot_path, format='jpg', dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def main(artifacts_path, variate='uni', horizon_len=5):
    """Load a saved NeuralForecast model, forecast, plot and store the results.

    Reads ``transformations.pkl``, the ``model`` directory and
    ``train_data.csv`` from ``artifacts_path``; writes ``forecast_plot.jpg``
    and ``prediction_results.csv`` back into the same directory.

    Args:
        artifacts_path: Directory holding the artifacts of one training run.
        variate: ``'univariate'`` (alias ``'uni'``) draws a single panel and
            passes the future frame through ``calendarFeatures``;
            ``'multivariate'`` (alias ``'multi'``) draws one panel per
            ``unique_id`` and uses the future frame as generated.
        horizon_len: The plot shows the last ``horizon_len * 2`` calendar
            days of the training data next to the forecast.

    Raises:
        ValueError: If ``variate`` is not one of the accepted values.
    """
    # Validate first: previously an unmatched value (including the default
    # 'uni') skipped both branches and crashed on the final to_csv call.
    mode = _VARIATE_MODES.get(variate)
    if mode is None:
        raise ValueError(
            f"Unknown variate {variate!r}; expected one of {sorted(_VARIATE_MODES)}"
        )

    # NOTE(review): pickle.load can execute arbitrary code; only use artifact
    # directories from a trusted source. The loaded object is not used in this
    # function; the load is kept so a missing or corrupt file still fails here.
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'rb') as fp:
        pickle.load(fp)

    nf2 = NeuralForecast.load(path=os.path.join(artifacts_path, 'model'))
    nf2.freq = 'B'  # 'B' is the pandas business-day offset alias

    Y_df = pd.read_csv(os.path.join(artifacts_path, 'train_data.csv'))
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    futr_df = nf2.make_future_dataframe()
    if mode == 'univariate':
        futr_df = calendarFeatures(futr_df)
    Y_hat = nf2.predict(futr_df=futr_df).reset_index()

    plot_data = _recent_history(Y_df, horizon_len)
    sns.set_theme(style="whitegrid")

    if mode == 'univariate':
        fig, ax = plt.subplots(1, 1, figsize=(20, 7))
        _draw_panel(ax, plot_data, Y_hat, 'Forecasting Result')
    else:
        unique_ids = Y_hat['unique_id'].unique()
        # squeeze=False always yields a 2-D axes array, so a single series
        # needs no special-casing.
        fig, axes = plt.subplots(
            len(unique_ids), 1,
            figsize=(20, 7 * len(unique_ids)),
            sharex=True,
            squeeze=False,
        )
        for ax, uid in zip(axes[:, 0], unique_ids):
            _draw_panel(
                ax,
                plot_data[plot_data['unique_id'] == uid],
                Y_hat[Y_hat['unique_id'] == uid],
                f'Forecasting Result for {uid}',
            )
        plt.tight_layout()

    _save_and_close(fig, artifacts_path)
    Y_hat.to_csv(os.path.join(artifacts_path, 'prediction_results.csv'))
if __name__ == "__main__":
    # Identifier of the training run whose artifacts are used for prediction.
    run_id = 'crude_oil_d2eea033-7bf9-437d-809c-d18f574a982d'  # This should be inputted (unique id for every run)

    # Artifacts live under <script directory>/artifacts/<run_id>.
    directory = Path(__file__).parent.absolute()
    artifacts_path = os.path.join(directory, 'artifacts', run_id)

    # Point the Lightning Trainer class at a run-specific log directory.
    log_dir = os.path.join(artifacts_path, 'prediction_logs')
    Trainer.default_root_dir = log_dir
    # NOTE(review): this logger is created but never passed to anything in
    # this script; confirm whether it is still needed.
    logger = CSVLogger(save_dir=log_dir, name='forecast_logs')

    main(
        artifacts_path=artifacts_path,
        variate='univariate',
        horizon_len=5,
    )