import pickle
from neuralforecast import NeuralForecast
from modules.transform import calendarFeatures
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
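
# Hypothetical helper (not part of the original script, not called by main()).
# A minimal sketch of the calendar that freq='B' implies, assuming that
# make_future_dataframe() yields roughly the next `h` business days (Mon-Fri)
# after each series' last observed date. It only reproduces that date logic
# with pandas so the expected future timestamps can be inspected on their own.
import pandas as pd


def next_business_days(last_date, h):
    """Return the `h` business days (Mon-Fri) strictly after `last_date`."""
    # BDay(1) moves to the next business day, skipping weekends.
    start = pd.Timestamp(last_date) + pd.offsets.BDay(1)
    # bdate_range generates consecutive business days starting at `start`.
    return pd.bdate_range(start=start, periods=h)
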

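def main(artifacts_path, variate='univariate', horizon_len=5):
    """Load a saved NeuralForecast model, forecast, and save a plot and a CSV.

    artifacts_path: run directory containing 'transformations.pkl',
        'model/' and 'train_data.csv'.
    variate: 'univariate' or 'multivariate'.
    horizon_len: only controls the plot window (the last 2 * horizon_len
        calendar days of history); the forecast horizon itself comes from
        the saved model.
    """
    # Transformations fitted at training time (loaded but not applied below).
    with open(os.path.join(artifacts_path, 'transformations.pkl'), 'rb') as fp:
        transformations_loaded = pickle.load(fp)

    # Restore the trained models and forecast on a business-day calendar.
    nf2 = NeuralForecast.load(path=os.path.join(artifacts_path, 'model'))
    nf2.freq = 'B'

    # Historical data, used only for the "Actual" line in the plot.
    training_data_path = os.path.join(artifacts_path, 'train_data.csv')
    Y_df = pd.read_csv(training_data_path)
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    # Future timestamps (one row per series and horizon step) for predict().
    futr_df = nf2.make_future_dataframe()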

    if variate == 'univariate':
        # Add calendar features to the future frame (presumably the same
        # features used as future exogenous inputs during training).
        futr_df = calendarFeatures(futr_df)
        Y_hat = nf2.predict(futr_df=futr_df).reset_index()

        # Keep only recent history so the forecast remains readable.
        plot_data = Y_df.copy()
        plot_data = plot_data[plot_data['ds'] >= (plot_data['ds'].max() - pd.Timedelta(days=horizon_len * 2))]

        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(20, 7))
        sns.lineplot(data=plot_data, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
        # One line per model column; skip identifier columns, including the
        # 'index' column reset_index() adds when unique_id is already a column.
        for col in Y_hat.columns:
            if col not in ['ds', 'unique_id', 'index']:
                sns.lineplot(data=Y_hat, x='ds', y=col, label=col, ax=ax, linewidth=2)

        ax.set_title('Forecasting Result', fontsize=22)
        ax.set_ylabel('Y', fontsize=20)
        ax.set_xlabel('Date', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid(True)

        plot_path = os.path.join(artifacts_path, 'forecast_plot.jpg')
        plt.savefig(plot_path, format='jpg', dpi=300)
        plt.close(fig)

    elif variate == 'multivariate':
        # No calendar features are added here; futr_df is used as generated.
        Y_hat = nf2.predict(futr_df=futr_df).reset_index()

        # Keep only recent history so the forecast remains readable.
        plot_data = Y_df.copy()
        plot_data = plot_data[plot_data['ds'] >= (plot_data['ds'].max() - pd.Timedelta(days=horizon_len * 2))]

        sns.set_theme(style="whitegrid")
        # One subplot per series.
        unique_ids = Y_hat['unique_id'].unique()
        fig, axes = plt.subplots(len(unique_ids), 1, figsize=(20, 7 * len(unique_ids)), sharex=True)
        if len(unique_ids) == 1:
            axes = [axes]  # plt.subplots returns a single Axes when nrows == 1

        for ax, uid in zip(axes, unique_ids):
            actual_data = plot_data[plot_data['unique_id'] == uid]
            forecast_data = Y_hat[Y_hat['unique_id'] == uid]

            sns.lineplot(data=actual_data, x='ds', y='y', label='Actual', ax=ax, color='blue', linewidth=2)
            # One line per model column; skip identifier columns.
            for col in Y_hat.columns:
                if col not in ['ds', 'unique_id', 'index']:
                    sns.lineplot(data=forecast_data, x='ds', y=col, label=col, ax=ax, linewidth=2)

            ax.set_title(f'Forecasting Result for {uid}', fontsize=22)
            ax.set_ylabel('Y', fontsize=20)
            ax.set_xlabel('Date', fontsize=20)
            ax.legend(prop={'size': 15})
            ax.grid(True)

        plt.tight_layout()
        plot_path = os.path.join(artifacts_path, 'forecast_plot.jpg')
        plt.savefig(plot_path, format='jpg', dpi=300)
        plt.close(fig)
    else:
        raise ValueError(f"variate must be 'univariate' or 'multivariate', got {variate!r}")

    # index=False avoids writing the DataFrame's row index as an extra column.
    Y_hat.to_csv(os.path.join(artifacts_path, 'prediction_results.csv'), index=False)

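
# Hypothetical alternative (not used by main()). main() keeps the last
# horizon_len * 2 *calendar* days of history for the plot, which with
# business-day ('B') data covers fewer than horizon_len * 2 observations.
# This sketch instead keeps the last `n_obs` rows of each series, assuming the
# long format used above (columns 'unique_id', 'ds', 'y').
import pandas as pd


def last_n_per_series(df, n_obs):
    """Return the last `n_obs` rows (ordered by 'ds') of every 'unique_id' series."""
    return (
        df.sort_values(["unique_id", "ds"])  # chronological order within each series
        .groupby("unique_id")
        .tail(n_obs)                         # last n_obs rows per group
        .reset_index(drop=True)
    )
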
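if __name__ == "__main__":
    directory = Path(__file__).parent.absolute()
    # Unique ID of the training run whose artifacts are used; should be supplied for every run.
    run_id = 'crude_oil_d2eea033-7bf9-437d-809c-d18f574a982d'
    artifacts_path = os.path.join(directory, 'artifacts', run_id)

    log_dir = os.path.join(artifacts_path, 'prediction_logs')
    # Overwrites the class attribute Trainer.default_root_dir (an instance
    # property in recent PyTorch Lightning releases), presumably so that
    # Trainers created during prediction write their files under log_dir.
    Trainer.default_root_dir = log_dir

    # Note: this logger is created but never passed to a Trainer.
    logger = CSVLogger(save_dir=log_dir, name='forecast_logs')

    main(artifacts_path=artifacts_path, variate='univariate', horizon_len=5)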