{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a3925b02-d72f-4936-aa50-bfa50c2487ec", "metadata": {}, "outputs": [], "source": [
 "import os\n",
 "\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "from sklearn.preprocessing import StandardScaler\n",
 "\n",
 "from neuralforecast.core import NeuralForecast\n",
 "from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate, NBEATSx\n",
 "from neuralforecast.losses.pytorch import MSE, MAE, MAPE"
 ] }, { "cell_type": "code", "execution_count": 18, "id": "b28bd72a-d97a-4511-8130-3f0bb4fdccd4", "metadata": {}, "outputs": [], "source": [
 "# Functions\n",
 "\n",
 "def createLag(data, amt=10):\n",
 "    \"\"\"\n",
 "    Shift the 'ds' column of a dataframe forward by `amt` business days.\n",
 "\n",
 "    Input:\n",
 "    data -> Pandas dataframe with a datetime-like 'ds' column\n",
 "    amt -> int, number of business days added to every 'ds' value\n",
 "\n",
 "    Output:\n",
 "    Copy of the dataframe with the shifted 'ds'. When there is no 'ds'\n",
 "    column a message is printed and the input itself is returned.\n",
 "    \"\"\"\n",
 "    if 'ds' not in data:\n",
 "        print(\"No 'ds' column found inside dataframe\")\n",
 "        return data\n",
 "\n",
 "    lagged = data.copy()\n",
 "    lagged['ds'] = lagged['ds'] + pd.tseries.offsets.BusinessDay(amt)\n",
 "    return lagged\n",
 "\n",
 "def trainTestValSplit(data, test_size, val_size):\n",
 "    \"\"\"\n",
 "    Placeholder for a train-test-validation split (not implemented yet).\n",
 "\n",
 "    Input:\n",
 "    data -> Pandas dataframe\n",
 "    test_size -> Proportion of data for test set\n",
 "    val_size -> Proportion of data for validation set\n",
 "\n",
 "    Output:\n",
 "    None for now; the split is not needed yet.\n",
 "    \"\"\"\n",
 "    pass\n",
 "\n",
 "def scaleStandard(df_col):\n",
 "    \"\"\"\n",
 "    Fit a StandardScaler on a 2-D column selection and return the scaled values.\n",
 "\n",
 "    Input:\n",
 "    df_col -> 2-D selection such as data[['y']] (StandardScaler rejects 1-D input)\n",
 "\n",
 "    Output:\n",
 "    (scaled, scaler) -> the scaled values as a dataframe that keeps the row\n",
 "    index of the input (columns are numbered 0..k-1), and the fitted scaler\n",
 "    so the transform can be inverted later.\n",
 "    \"\"\"\n",
 "    scaler = StandardScaler()\n",
 "    scaled = scaler.fit_transform(df_col)\n",
 "    # Keep the caller's row labels. With a fresh RangeIndex, assigning the result\n",
 "    # back into a frame re-aligns it by label and puts values on the wrong rows\n",
 "    # whenever that frame is not in 0..n-1 order (e.g. after sort_values).\n",
 "    scaled = pd.DataFrame(scaled, index=getattr(df_col, 'index', None))\n",
 "    return scaled, scaler\n",
 "\n",
 "def logReturn(data, df_col):\n",
 "    \"\"\"\n",
 "    Log return of one column: log(1 + pct_change), i.e. log(x_t / x_{t-1}).\n",
 "\n",
 "    The first row has no predecessor, so its value is NaN.\n",
 "    \"\"\"\n",
 "    return np.log1p(data[df_col].pct_change())\n",
 "\n",
 "def transformData(data, log_return=(), standard_scale=()):\n",
 "    \"\"\"\n",
 "    Sort by 'ds', then apply log-return and standard-scale transforms per column.\n",
 "\n",
 "    Input:\n",
 "    data -> Pandas dataframe with a 'ds' column; it is modified IN PLACE\n",
 "    log_return -> column names to replace by their log return\n",
 "    standard_scale -> column names to replace by their standard-scaled values\n",
 "\n",
 "    Output:\n",
 "    The same dataframe object that was passed in, sorted by 'ds'. A column\n",
 "    that cannot be transformed is reported and left as it was.\n",
 "    \"\"\"\n",
 "    # In place on purpose: callers keep using the frame they passed in.\n",
 "    data.sort_values(by='ds', inplace=True)\n",
 "\n",
 "    for col in log_return:\n",
 "        try:\n",
 "            data[col] = logReturn(data, col)\n",
 "        except Exception as e:\n",
 "            print(f\"Log return skipped for column '{col}': {e}\")\n",
 "\n",
 "    for col in standard_scale:\n",
 "        try:\n",
 "            scaled, _ = scaleStandard(data[[col]])\n",
 "            # Positional assignment: rows are already in the order that was scaled,\n",
 "            # so no label alignment is needed (or wanted). Each column, 'y'\n",
 "            # included, is scaled exactly once.\n",
 "            data[col] = scaled.to_numpy().ravel()\n",
 "        except Exception as e:\n",
 "            print(f\"Standard scaling skipped for column '{col}': {e}\")\n",
 "\n",
 "    return data"
 ] }, { "cell_type": "code", "execution_count": 19, "id": "dc3678db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total length is 1962, with validation and test size of 49 for each\n" ] } ], "source": [
 "# Exogenous \n",
 "\n",
 "csv_path = os.path.join('dataset', 'DatedBrent', 'priceForecast_nosent_1.csv')\n",
 "Y_df = (\n",
 "    pd.read_csv(csv_path)\n",
 "    .rename(columns={'date': 'ds', 'BrDa': 'y'})\n",
 "    .drop(columns=['Unnamed: 0'])\n",
 ")\n",
 "Y_df['unique_id'] = 'Dated'\n",
 "Y_df['ds'] = pd.to_datetime(Y_df['ds'])\n",
 "\n",
 "# We make validation and test splits (2.5% of the distinct dates each)\n",
 "n_time = len(Y_df['ds'].unique())\n",
 "val_size = int(.025 * n_time)\n",
 "test_size = int(.025 * n_time)\n",
 "\n",
 "print(f'Total length is {n_time}, with validation and test size of {val_size} for each')"
 ] }, { "cell_type": "code", "execution_count": 20, "id": "41f4b728", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dsBrFuGasOVXDXYGPRDyunique_id
02024-08-0273.522.317633.419998103.209999207.79943878.265Dated
12024-08-0176.932.398033.689999104.419998139.87809881.825Dated
22024-07-3177.912.442531.490000104.099998135.20684881.450Dated
32024-07-3074.732.344330.820000104.55000395.69639679.140Dated
42024-07-2677.162.417026.490000104.320000105.65412981.245Dated
...........................
19572014-08-1597.352.575017.86000181.419998131.822342100.775Dated
19582014-08-1495.582.540317.19000181.589996106.742241100.445Dated
19592014-08-1397.592.626617.54999981.599998165.846603101.855Dated
19602014-08-1297.372.604517.54000181.500000150.571259101.725Dated
19612014-08-0897.652.753718.85000081.389999212.396439103.315Dated
\n", "

1962 rows × 8 columns

\n", "
" ], "text/plain": [ " ds BrFu Gas OVX DXY GPRD y \\\n", "0 2024-08-02 73.52 2.3176 33.419998 103.209999 207.799438 78.265 \n", "1 2024-08-01 76.93 2.3980 33.689999 104.419998 139.878098 81.825 \n", "2 2024-07-31 77.91 2.4425 31.490000 104.099998 135.206848 81.450 \n", "3 2024-07-30 74.73 2.3443 30.820000 104.550003 95.696396 79.140 \n", "4 2024-07-26 77.16 2.4170 26.490000 104.320000 105.654129 81.245 \n", "... ... ... ... ... ... ... ... \n", "1957 2014-08-15 97.35 2.5750 17.860001 81.419998 131.822342 100.775 \n", "1958 2014-08-14 95.58 2.5403 17.190001 81.589996 106.742241 100.445 \n", "1959 2014-08-13 97.59 2.6266 17.549999 81.599998 165.846603 101.855 \n", "1960 2014-08-12 97.37 2.6045 17.540001 81.500000 150.571259 101.725 \n", "1961 2014-08-08 97.65 2.7537 18.850000 81.389999 212.396439 103.315 \n", "\n", " unique_id \n", "0 Dated \n", "1 Dated \n", "2 Dated \n", "3 Dated \n", "4 Dated \n", "... ... \n", "1957 Dated \n", "1958 Dated \n", "1959 Dated \n", "1960 Dated \n", "1961 Dated \n", "\n", "[1962 rows x 8 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_df" ] }, { "cell_type": "code", "execution_count": 21, "id": "11685932-4158-49d9-861c-0b917d9f45e9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dsyBrFuGasOVXDXYGPRDBrDaunique_id
02024-08-0278.26580.732.496624.230000105.800003128.75705087.015Dated
12024-07-3079.14081.572.465224.170000105.260002144.02468985.280Dated
22024-07-2681.24578.452.387724.450001105.55000390.60014381.730Dated
32024-07-2582.65578.622.402024.799999105.199997134.21084681.640Dated
42024-07-2482.99578.502.394425.540001104.650002118.30599281.245Dated
..............................
18122014-09-2695.25097.352.575017.86000181.419998131.822342100.775Dated
18132014-09-2595.46595.582.540317.19000181.589996106.742241100.445Dated
18142014-09-2494.58597.592.626617.54999981.599998165.846603101.855Dated
18152014-09-2395.06597.372.604517.54000181.500000150.571259101.725Dated
18162014-09-1996.64597.652.753718.85000081.389999212.396439103.315Dated
\n", "

1817 rows × 9 columns

\n", "
" ], "text/plain": [ " ds y BrFu Gas OVX DXY GPRD \\\n", "0 2024-08-02 78.265 80.73 2.4966 24.230000 105.800003 128.757050 \n", "1 2024-07-30 79.140 81.57 2.4652 24.170000 105.260002 144.024689 \n", "2 2024-07-26 81.245 78.45 2.3877 24.450001 105.550003 90.600143 \n", "3 2024-07-25 82.655 78.62 2.4020 24.799999 105.199997 134.210846 \n", "4 2024-07-24 82.995 78.50 2.3944 25.540001 104.650002 118.305992 \n", "... ... ... ... ... ... ... ... \n", "1812 2014-09-26 95.250 97.35 2.5750 17.860001 81.419998 131.822342 \n", "1813 2014-09-25 95.465 95.58 2.5403 17.190001 81.589996 106.742241 \n", "1814 2014-09-24 94.585 97.59 2.6266 17.549999 81.599998 165.846603 \n", "1815 2014-09-23 95.065 97.37 2.6045 17.540001 81.500000 150.571259 \n", "1816 2014-09-19 96.645 97.65 2.7537 18.850000 81.389999 212.396439 \n", "\n", " BrDa unique_id \n", "0 87.015 Dated \n", "1 85.280 Dated \n", "2 81.730 Dated \n", "3 81.640 Dated \n", "4 81.245 Dated \n", "... ... ... \n", "1812 100.775 Dated \n", "1813 100.445 Dated \n", "1814 101.855 Dated \n", "1815 101.725 Dated \n", "1816 103.315 Dated \n", "\n", "[1817 rows x 9 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Pair every target date with the regressors seen 30 business days earlier:\n",
 "# the regressor rows get their 'ds' pushed forward, then are joined back on 'ds'.\n",
 "# The old target travels along as the lagged 'BrDa' regressor. The join is the\n",
 "# default inner join, so dates without a shifted counterpart are dropped.\n",
 "Y_df_test = createLag(Y_df.rename(columns={'y': 'BrDa'}), amt=30)\n",
 "df = pd.merge(Y_df[['ds', 'y']], Y_df_test, on='ds')\n",
 "df"
 ] }, { "cell_type": "code", "execution_count": 22, "id": "8dacad81-f2ed-48ee-a72b-8a3efdb4f112", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dsyBrFuGasOVXDXYGPRDBrDaunique_id
18152014-09-231.448573-0.002872-0.05570517.5400010.001351150.571259-0.015509Dated
18142014-09-241.4244940.0022570.00845017.5499990.001226165.8466030.001277Dated
18132014-09-251.468639-0.020811-0.03340817.190001-0.000123106.742241-0.013940Dated
18122014-09-261.4578540.0183490.01356717.860001-0.002086131.8223420.003280Dated
18112014-09-301.435279-0.029925-0.00443719.5200000.005634216.154495-0.010825Dated
..............................
42024-07-240.8430810.007673-0.00603825.540001-0.005527118.3059920.006235Dated
32024-07-250.8260240.0015270.00316924.7999990.005242134.2108460.004850Dated
22024-07-260.755292-0.002165-0.00597124.4500010.00332290.6001430.001102Dated
12024-07-300.6496940.0390000.03194224.170000-0.002751144.0246890.042519Dated
02024-08-020.605800-0.0103510.01265724.2300000.005117128.7570500.020141Dated
\n", "

1816 rows × 9 columns

\n", "
" ], "text/plain": [ " ds y BrFu Gas OVX DXY \\\n", "1815 2014-09-23 1.448573 -0.002872 -0.055705 17.540001 0.001351 \n", "1814 2014-09-24 1.424494 0.002257 0.008450 17.549999 0.001226 \n", "1813 2014-09-25 1.468639 -0.020811 -0.033408 17.190001 -0.000123 \n", "1812 2014-09-26 1.457854 0.018349 0.013567 17.860001 -0.002086 \n", "1811 2014-09-30 1.435279 -0.029925 -0.004437 19.520000 0.005634 \n", "... ... ... ... ... ... ... \n", "4 2024-07-24 0.843081 0.007673 -0.006038 25.540001 -0.005527 \n", "3 2024-07-25 0.826024 0.001527 0.003169 24.799999 0.005242 \n", "2 2024-07-26 0.755292 -0.002165 -0.005971 24.450001 0.003322 \n", "1 2024-07-30 0.649694 0.039000 0.031942 24.170000 -0.002751 \n", "0 2024-08-02 0.605800 -0.010351 0.012657 24.230000 0.005117 \n", "\n", " GPRD BrDa unique_id \n", "1815 150.571259 -0.015509 Dated \n", "1814 165.846603 0.001277 Dated \n", "1813 106.742241 -0.013940 Dated \n", "1812 131.822342 0.003280 Dated \n", "1811 216.154495 -0.010825 Dated \n", "... ... ... ... 
\n", "4 118.305992 0.006235 Dated \n", "3 134.210846 0.004850 Dated \n", "2 90.600143 0.001102 Dated \n", "1 144.024689 0.042519 Dated \n", "0 128.757050 0.020141 Dated \n", "\n", "[1816 rows x 9 columns]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Log-return the price-like regressors and standard-scale the target.\n",
 "test_data = transformData(\n",
 "    df,\n",
 "    log_return=['BrFu', 'Gas', 'DXY', 'BrDa'],\n",
 "    standard_scale=['y'],\n",
 ")\n",
 "# transformData hands back the very object it received, so `test_data` and `df`\n",
 "# are the same frame and the dropna below trims both. It removes the first row,\n",
 "# whose log returns are NaN because it has no predecessor.\n",
 "test_data.dropna(inplace=True)\n",
 "test_data"
 ] }, { "cell_type": "markdown", "id": "2b862eb0-afca-489b-9289-0fa2efaaf1d4", "metadata": {}, "source": [ "---" ] }, { "cell_type": "code", "execution_count": 23, "id": "0d42f7f7-8c6f-4708-88e8-f4e4fdd081db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Seed set to 12345678\n", "Seed set to 12345678\n", "Seed set to 12345678\n" ] } ], "source": [
 "horizon = 30\n",
 "input_size = 30 * 2\n",
 "\n",
 "# Settings every model below has in common.\n",
 "shared_kwargs = dict(\n",
 "    h=horizon,\n",
 "    max_steps=1000,\n",
 "    val_check_steps=100,\n",
 "    early_stop_patience_steps=5,\n",
 "    scaler_type='identity',\n",
 "    random_seed=12345678,\n",
 ")\n",
 "# Regressors handed to the exogenous-aware models as future-known inputs.\n",
 "exog_columns = ['Gas', 'DXY', 'BrFu', 'BrDa']\n",
 "\n",
 "# NOTE(review): the target is standard-scaled upstream, so it takes values close\n",
 "# to zero; MAPE as a training/validation loss may be unstable there -- confirm.\n",
 "# NOTE(review): NBEATSx uses input_size=horizon while the TSMixer models use\n",
 "# `input_size`; confirm this is intended.\n",
 "models = [\n",
 "    TSMixer(\n",
 "        input_size=input_size,\n",
 "        n_series=1,\n",
 "        loss=MAPE(),\n",
 "        valid_loss=MAPE(),\n",
 "        **shared_kwargs,\n",
 "    ),\n",
 "    TSMixerx(\n",
 "        input_size=input_size,\n",
 "        n_series=1,\n",
 "        dropout=0.7,\n",
 "        loss=MAPE(),\n",
 "        valid_loss=MAPE(),\n",
 "        futr_exog_list=list(exog_columns),\n",
 "        **shared_kwargs,\n",
 "    ),\n",
 "    NBEATSx(\n",
 "        input_size=horizon,\n",
 "        loss=MAPE(),\n",
 "        valid_loss=MAPE(),\n",
 "        futr_exog_list=list(exog_columns),\n",
 "        **shared_kwargs,\n",
 "    ),\n",
 "]"
 ] }, { "cell_type": "code", "execution_count": 24, "id": "1ef4261e-2256-4173-b6fa-a75a12062b11", "metadata": {}, "outputs": [], "source": [
 "# NOTE(review): freq='D' although the lag helper works in business days and the\n",
 "# dates have gaps -- confirm the frequency.\n",
 "nf = NeuralForecast(models=models, freq='D')"
 ] }, { "cell_type": "code", "execution_count": 25, "id": "1157bb88-22fb-45c7-a36b-73149b873dd6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode \n", "-------------------------------------------------------------------\n", "0 | loss | MAPE | 0 | train\n", "1 | valid_loss | MAPE | 0 | train\n", "2 | padder | ConstantPad1d | 0 | train\n", "3 | scaler | TemporalNorm | 0 | train\n", "4 | norm | ReversibleInstanceNorm1d | 2 | train\n", "5 | mixing_layers | Sequential | 8.2 K | train\n", "6 | out | Linear | 1.8 K | train\n", "-------------------------------------------------------------------\n", "10.0 K Trainable params\n", "0 Non-trainable params\n", "10.0 K Total params\n", "0.040 Total estimated model params size (MB)\n", "29 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fda2e4a4cc8470d8c1b9682bf7080f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00]. Skipping setting a default `ModelSummary` callback.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c15d29b08454c02ba907c6acf2810d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00]. 
Skipping setting a default `ModelSummary` callback.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ccd2baca479b4ce8801170f7fd9979cf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00]. Skipping setting a default `ModelSummary` callback.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dedfdcf9fdc3416fa0deeb8cba3d844e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
unique_iddscutoffTSMixerTSMixerxNBEATSxy
0Dated2024-04-102024-04-091.1772871.3907481.2212201.254183
1Dated2024-04-112024-04-091.1557551.3502571.2419361.285286
2Dated2024-04-122024-04-091.2236311.3337971.1557511.362289
3Dated2024-04-162024-04-091.1910171.3711891.0998871.259200
4Dated2024-04-172024-04-091.1832281.3626611.1619991.177431
........................
595Dated2024-07-242024-05-211.0902670.9981370.9349020.843081
596Dated2024-07-252024-05-211.0894880.9501860.9564400.826024
597Dated2024-07-262024-05-210.9827491.0212580.8852820.755292
598Dated2024-07-302024-05-211.0830790.9302270.8834450.649694
599Dated2024-08-022024-05-211.0331080.9680540.8570220.605800
\n", "

600 rows × 7 columns

\n", "" ], "text/plain": [ " unique_id ds cutoff TSMixer TSMixerx NBEATSx y\n", "0 Dated 2024-04-10 2024-04-09 1.177287 1.390748 1.221220 1.254183\n", "1 Dated 2024-04-11 2024-04-09 1.155755 1.350257 1.241936 1.285286\n", "2 Dated 2024-04-12 2024-04-09 1.223631 1.333797 1.155751 1.362289\n", "3 Dated 2024-04-16 2024-04-09 1.191017 1.371189 1.099887 1.259200\n", "4 Dated 2024-04-17 2024-04-09 1.183228 1.362661 1.161999 1.177431\n", ".. ... ... ... ... ... ... ...\n", "595 Dated 2024-07-24 2024-05-21 1.090267 0.998137 0.934902 0.843081\n", "596 Dated 2024-07-25 2024-05-21 1.089488 0.950186 0.956440 0.826024\n", "597 Dated 2024-07-26 2024-05-21 0.982749 1.021258 0.885282 0.755292\n", "598 Dated 2024-07-30 2024-05-21 1.083079 0.930227 0.883445 0.649694\n", "599 Dated 2024-08-02 2024-05-21 1.033108 0.968054 0.857022 0.605800\n", "\n", "[600 rows x 7 columns]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_hat_df" ] }, { "cell_type": "code", "execution_count": 30, "id": "10d24df3-901f-449a-ad2e-cb6a7a32ff74", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "import matplotlib.pyplot as plt\n",
 "\n",
 "# Keep the 'Dated' series and only every `horizon`-th cutoff of the forecasts.\n",
 "Y_plot = Y_hat_df[Y_hat_df['unique_id']=='Dated']\n",
 "cutoffs = Y_hat_df['cutoff'].unique()[::horizon]\n",
 "# Build the mask from Y_plot itself: a mask taken from Y_hat_df only lines up\n",
 "# while both frames share exactly the same index.\n",
 "Y_plot = Y_plot[Y_plot['cutoff'].isin(cutoffs)]\n",
 "\n",
 "plt.figure(figsize=(20,5))\n",
 "plt.plot(Y_plot['ds'], Y_plot['y'], label='True')\n",
 "for model in models:\n",
 "    # The forecast columns are named after each model's string form.\n",
 "    plt.plot(Y_plot['ds'], Y_plot[f'{model}'], label=f'{model}')\n",
 "plt.xlabel('Datestamp')\n",
 "# The plotted target is the standard-scaled 'y' column.\n",
 "plt.ylabel('y (standard-scaled)')\n",
 "plt.grid()\n",
 "plt.legend()"
 ] }, { "cell_type": "code", "execution_count": 29, "id": "55f84c08-6ffd-49d7-8798-4cbc5209e550", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TSMixer horizon 30 - MAE: 0.259\n", "TSMixer horizon 30 - MSE: 0.092\n", "TSMixer horizon 30 - MAPE: 0.363\n", "TSMixerx horizon 30 - MAE: 0.274\n", "TSMixerx horizon 30 - MSE: 0.113\n", "TSMixerx horizon 30 - MAPE: 0.367\n", "NBEATSx horizon 30 - MAE: 0.195\n", "NBEATSx horizon 30 - MSE: 0.067\n", "NBEATSx horizon 30 - MAPE: 0.234\n" ] } ], "source": [
 "from neuralforecast.losses.numpy import mse, mae, mape\n",
 "\n",
 "# Metric label -> function, in the order the scores are reported.\n",
 "# NOTE(review): 'y' is standard-scaled, so it takes values close to zero and\n",
 "# MAPE may be dominated by those rows -- read it with care.\n",
 "metric_fns = {'MAE': mae, 'MSE': mse, 'MAPE': mape}\n",
 "\n",
 "for model in models:\n",
 "    model_name = f'{model}'\n",
 "    for label, metric_fn in metric_fns.items():\n",
 "        score = metric_fn(Y_hat_df['y'], Y_hat_df[model_name])\n",
 "        print(f'{model_name} horizon {horizon} - {label}: {score:.3f}')"
 ] }, { "cell_type": "code", "execution_count": 14, "id": "28baeafb-d96e-405d-b460-bd2001e7f160", "metadata": {}, "outputs": [], "source": [
 "# import matplotlib.pyplot as plt\n",
 "# Y_plot = Y_hat_df[Y_hat_df['unique_id']=='Future']\n",
 "# cutoffs = Y_hat_df['cutoff'].unique()[::horizon]\n",
 "# Y_plot = Y_plot[Y_hat_df['cutoff'].isin(cutoffs)]\n",
 "\n",
 "# plt.figure(figsize=(20,5))\n",
 "# plt.plot(Y_plot['ds'], Y_plot['y'], label='True')\n",
 "# for model in models:\n",
 "# plt.plot(Y_plot['ds'], Y_plot[f'{model}'], label=f'{model}')\n",
 "# plt.xlabel('Datestamp')\n",
 "# plt.ylabel('OT')\n",
 "# plt.grid()\n",
 "# plt.legend()"
 ] }, { "cell_type": "code", "execution_count": null, "id": "5be8e1f9-a92d-4edb-a992-8021b605333c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pytorch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }