{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a3925b02-d72f-4936-aa50-bfa50c2487ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from neuralforecast.core import NeuralForecast\n",
    "from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate, NBEATSx\n",
    "from neuralforecast.losses.pytorch import MSE, MAE, MAPE\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b28bd72a-d97a-4511-8130-3f0bb4fdccd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions\n",
    "\n",
    "def createLag(data, amt=10):\n",
    "    \"\"\"\n",
    "    Shift the 'ds' column of a dataframe forward, in business days\n",
    "\n",
    "    Input:\n",
    "        data -> Pandas dataframe\n",
    "        amt -> int, number of business days added to every 'ds' value\n",
    "\n",
    "    Output:\n",
    "        Copy of pandas Dataframe with the shifted 'ds' column. When there\n",
    "        is no 'ds' column a message is printed and the original dataframe\n",
    "        (not a copy) is returned.\n",
    "    \"\"\"\n",
    "    if 'ds' not in data:\n",
    "        print(\"No 'ds' column found inside dataframe\")\n",
    "        return data\n",
    "\n",
    "    lagged = data.copy()\n",
    "    lagged['ds'] = lagged['ds'] + pd.tseries.offsets.BusinessDay(amt)\n",
    "    return lagged\n",
    "\n",
    "def trainTestValSplit(data, test_size, val_size):\n",
    "    \"\"\"\n",
    "    Placeholder for a train-test-validation split (not implemented yet)\n",
    "\n",
    "    Input:\n",
    "        data -> Pandas dataframe\n",
    "        test_size -> Proportion of data for test set\n",
    "        val_size -> Proportion of data for validation set\n",
    "\n",
    "    Output:\n",
    "        None, the function currently does nothing\n",
    "    \"\"\"\n",
    "    pass\n",
    "\n",
    "def scaleStandard(df_col):\n",
    "    \"\"\"\n",
    "    Fits a StandardScaler and returns the scaled version of a dataframe column\n",
    "\n",
    "    Input:\n",
    "        df_col -> 2D input for StandardScaler, e.g. data[['y']]\n",
    "\n",
    "    Output:\n",
    "        (scaled, scaler) -> dataframe with the scaled values and the\n",
    "        fitted StandardScaler\n",
    "    \"\"\"\n",
    "    scaler = StandardScaler()\n",
    "    scaled = scaler.fit_transform(df_col)\n",
    "    # Keep the caller's row labels. With a fresh RangeIndex,\n",
    "    # `data[col] = scaled` aligns on labels and puts values on the wrong\n",
    "    # rows whenever data's index is not 0..n-1 in positional order.\n",
    "    index = df_col.index if isinstance(df_col, pd.DataFrame) else None\n",
    "    scaled = pd.DataFrame(scaled, index=index)\n",
    "    return scaled, scaler\n",
    "\n",
    "def logReturn(data, df_col):\n",
    "    \"\"\"\n",
    "    Perform log return for a dataframe column\n",
    "\n",
    "    Computed as log(1 + pct_change) between consecutive rows, so the\n",
    "    first row has no previous value and comes out as NaN.\n",
    "    \"\"\"\n",
    "    return np.log1p(data[df_col].pct_change())\n",
    "\n",
    "def transformData(data, log_return=None, standard_scale=None):\n",
    "    \"\"\"\n",
    "    Perform essential transformations towards the dataframe\n",
    "\n",
    "    The dataframe is sorted by 'ds' and modified in place; the same\n",
    "    object is returned.\n",
    "\n",
    "    Input:\n",
    "        data -> Pandas dataframe with a 'ds' column\n",
    "        log_return -> column names to replace with their log return\n",
    "        standard_scale -> column names to standard scale (done after the\n",
    "                          log return when a column is in both lists)\n",
    "\n",
    "    Output:\n",
    "        The transformed dataframe (same object as data)\n",
    "    \"\"\"\n",
    "    log_return = [] if log_return is None else list(log_return)\n",
    "    standard_scale = [] if standard_scale is None else list(standard_scale)\n",
    "\n",
    "    # Log returns compare consecutive rows, so order the rows by date first.\n",
    "    data.sort_values(by='ds', inplace=True)\n",
    "\n",
    "    for col in log_return:\n",
    "        try:\n",
    "            data[col] = logReturn(data, col)\n",
    "        except Exception as e:\n",
    "            # Best effort: report the column that failed and keep going.\n",
    "            print(f\"Could not apply log return to '{col}': {e!r}\")\n",
    "\n",
    "    for col in standard_scale:\n",
    "        try:\n",
    "            scaled, _ = scaleStandard(data[[col]])\n",
    "            # Assign by position: after sort_values the index is out of\n",
    "            # positional order, so label alignment would misplace values.\n",
    "            data[col] = scaled.to_numpy().ravel()\n",
    "        except Exception as e:\n",
    "            print(f\"Could not standard scale '{col}': {e!r}\")\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "dc3678db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total length is 1962, with validation and test size of 49 for each\n"
     ]
    }
   ],
   "source": [
    "# Exogenous\n",
    "\n",
    "Y_df = (\n",
    "    pd.read_csv(os.path.join('dataset', 'DatedBrent', 'priceForecast_nosent_1.csv'))\n",
    "    .rename({'date' : 'ds', 'BrDa' : 'y'}, axis=1)\n",
    "    .drop(columns=['Unnamed: 0'])\n",
    ")\n",
    "Y_df['unique_id'] = 'Dated'\n",
    "Y_df['ds'] = pd.to_datetime(Y_df['ds'])\n",
    "\n",
    "# We make validation and test splits\n",
    "n_time = len(Y_df.ds.unique())\n",
    "val_size = int(.025 * n_time)\n",
    "test_size = int(.025 * n_time)\n",
    "\n",
    "print(f'Total length is {n_time}, with validation and test size of {val_size} for each')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "41f4b728",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | ds | \n", "BrFu | \n", "Gas | \n", "OVX | \n", "DXY | \n", "GPRD | \n", "y | \n", "unique_id | \n", "
---|---|---|---|---|---|---|---|---|
0 | \n", "2024-08-02 | \n", "73.52 | \n", "2.3176 | \n", "33.419998 | \n", "103.209999 | \n", "207.799438 | \n", "78.265 | \n", "Dated | \n", "
1 | \n", "2024-08-01 | \n", "76.93 | \n", "2.3980 | \n", "33.689999 | \n", "104.419998 | \n", "139.878098 | \n", "81.825 | \n", "Dated | \n", "
2 | \n", "2024-07-31 | \n", "77.91 | \n", "2.4425 | \n", "31.490000 | \n", "104.099998 | \n", "135.206848 | \n", "81.450 | \n", "Dated | \n", "
3 | \n", "2024-07-30 | \n", "74.73 | \n", "2.3443 | \n", "30.820000 | \n", "104.550003 | \n", "95.696396 | \n", "79.140 | \n", "Dated | \n", "
4 | \n", "2024-07-26 | \n", "77.16 | \n", "2.4170 | \n", "26.490000 | \n", "104.320000 | \n", "105.654129 | \n", "81.245 | \n", "Dated | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1957 | \n", "2014-08-15 | \n", "97.35 | \n", "2.5750 | \n", "17.860001 | \n", "81.419998 | \n", "131.822342 | \n", "100.775 | \n", "Dated | \n", "
1958 | \n", "2014-08-14 | \n", "95.58 | \n", "2.5403 | \n", "17.190001 | \n", "81.589996 | \n", "106.742241 | \n", "100.445 | \n", "Dated | \n", "
1959 | \n", "2014-08-13 | \n", "97.59 | \n", "2.6266 | \n", "17.549999 | \n", "81.599998 | \n", "165.846603 | \n", "101.855 | \n", "Dated | \n", "
1960 | \n", "2014-08-12 | \n", "97.37 | \n", "2.6045 | \n", "17.540001 | \n", "81.500000 | \n", "150.571259 | \n", "101.725 | \n", "Dated | \n", "
1961 | \n", "2014-08-08 | \n", "97.65 | \n", "2.7537 | \n", "18.850000 | \n", "81.389999 | \n", "212.396439 | \n", "103.315 | \n", "Dated | \n", "
1962 rows × 8 columns
\n", "\n", " | ds | \n", "y | \n", "BrFu | \n", "Gas | \n", "OVX | \n", "DXY | \n", "GPRD | \n", "BrDa | \n", "unique_id | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "2024-08-02 | \n", "78.265 | \n", "80.73 | \n", "2.4966 | \n", "24.230000 | \n", "105.800003 | \n", "128.757050 | \n", "87.015 | \n", "Dated | \n", "
1 | \n", "2024-07-30 | \n", "79.140 | \n", "81.57 | \n", "2.4652 | \n", "24.170000 | \n", "105.260002 | \n", "144.024689 | \n", "85.280 | \n", "Dated | \n", "
2 | \n", "2024-07-26 | \n", "81.245 | \n", "78.45 | \n", "2.3877 | \n", "24.450001 | \n", "105.550003 | \n", "90.600143 | \n", "81.730 | \n", "Dated | \n", "
3 | \n", "2024-07-25 | \n", "82.655 | \n", "78.62 | \n", "2.4020 | \n", "24.799999 | \n", "105.199997 | \n", "134.210846 | \n", "81.640 | \n", "Dated | \n", "
4 | \n", "2024-07-24 | \n", "82.995 | \n", "78.50 | \n", "2.3944 | \n", "25.540001 | \n", "104.650002 | \n", "118.305992 | \n", "81.245 | \n", "Dated | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1812 | \n", "2014-09-26 | \n", "95.250 | \n", "97.35 | \n", "2.5750 | \n", "17.860001 | \n", "81.419998 | \n", "131.822342 | \n", "100.775 | \n", "Dated | \n", "
1813 | \n", "2014-09-25 | \n", "95.465 | \n", "95.58 | \n", "2.5403 | \n", "17.190001 | \n", "81.589996 | \n", "106.742241 | \n", "100.445 | \n", "Dated | \n", "
1814 | \n", "2014-09-24 | \n", "94.585 | \n", "97.59 | \n", "2.6266 | \n", "17.549999 | \n", "81.599998 | \n", "165.846603 | \n", "101.855 | \n", "Dated | \n", "
1815 | \n", "2014-09-23 | \n", "95.065 | \n", "97.37 | \n", "2.6045 | \n", "17.540001 | \n", "81.500000 | \n", "150.571259 | \n", "101.725 | \n", "Dated | \n", "
1816 | \n", "2014-09-19 | \n", "96.645 | \n", "97.65 | \n", "2.7537 | \n", "18.850000 | \n", "81.389999 | \n", "212.396439 | \n", "103.315 | \n", "Dated | \n", "
1817 rows × 9 columns
\n", "\n", " | ds | \n", "y | \n", "BrFu | \n", "Gas | \n", "OVX | \n", "DXY | \n", "GPRD | \n", "BrDa | \n", "unique_id | \n", "
---|---|---|---|---|---|---|---|---|---|
1815 | \n", "2014-09-23 | \n", "1.448573 | \n", "-0.002872 | \n", "-0.055705 | \n", "17.540001 | \n", "0.001351 | \n", "150.571259 | \n", "-0.015509 | \n", "Dated | \n", "
1814 | \n", "2014-09-24 | \n", "1.424494 | \n", "0.002257 | \n", "0.008450 | \n", "17.549999 | \n", "0.001226 | \n", "165.846603 | \n", "0.001277 | \n", "Dated | \n", "
1813 | \n", "2014-09-25 | \n", "1.468639 | \n", "-0.020811 | \n", "-0.033408 | \n", "17.190001 | \n", "-0.000123 | \n", "106.742241 | \n", "-0.013940 | \n", "Dated | \n", "
1812 | \n", "2014-09-26 | \n", "1.457854 | \n", "0.018349 | \n", "0.013567 | \n", "17.860001 | \n", "-0.002086 | \n", "131.822342 | \n", "0.003280 | \n", "Dated | \n", "
1811 | \n", "2014-09-30 | \n", "1.435279 | \n", "-0.029925 | \n", "-0.004437 | \n", "19.520000 | \n", "0.005634 | \n", "216.154495 | \n", "-0.010825 | \n", "Dated | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
4 | \n", "2024-07-24 | \n", "0.843081 | \n", "0.007673 | \n", "-0.006038 | \n", "25.540001 | \n", "-0.005527 | \n", "118.305992 | \n", "0.006235 | \n", "Dated | \n", "
3 | \n", "2024-07-25 | \n", "0.826024 | \n", "0.001527 | \n", "0.003169 | \n", "24.799999 | \n", "0.005242 | \n", "134.210846 | \n", "0.004850 | \n", "Dated | \n", "
2 | \n", "2024-07-26 | \n", "0.755292 | \n", "-0.002165 | \n", "-0.005971 | \n", "24.450001 | \n", "0.003322 | \n", "90.600143 | \n", "0.001102 | \n", "Dated | \n", "
1 | \n", "2024-07-30 | \n", "0.649694 | \n", "0.039000 | \n", "0.031942 | \n", "24.170000 | \n", "-0.002751 | \n", "144.024689 | \n", "0.042519 | \n", "Dated | \n", "
0 | \n", "2024-08-02 | \n", "0.605800 | \n", "-0.010351 | \n", "0.012657 | \n", "24.230000 | \n", "0.005117 | \n", "128.757050 | \n", "0.020141 | \n", "Dated | \n", "
1816 rows × 9 columns
\n", "\n", " | unique_id | \n", "ds | \n", "cutoff | \n", "TSMixer | \n", "TSMixerx | \n", "NBEATSx | \n", "y | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "Dated | \n", "2024-04-10 | \n", "2024-04-09 | \n", "1.177287 | \n", "1.390748 | \n", "1.221220 | \n", "1.254183 | \n", "
1 | \n", "Dated | \n", "2024-04-11 | \n", "2024-04-09 | \n", "1.155755 | \n", "1.350257 | \n", "1.241936 | \n", "1.285286 | \n", "
2 | \n", "Dated | \n", "2024-04-12 | \n", "2024-04-09 | \n", "1.223631 | \n", "1.333797 | \n", "1.155751 | \n", "1.362289 | \n", "
3 | \n", "Dated | \n", "2024-04-16 | \n", "2024-04-09 | \n", "1.191017 | \n", "1.371189 | \n", "1.099887 | \n", "1.259200 | \n", "
4 | \n", "Dated | \n", "2024-04-17 | \n", "2024-04-09 | \n", "1.183228 | \n", "1.362661 | \n", "1.161999 | \n", "1.177431 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
595 | \n", "Dated | \n", "2024-07-24 | \n", "2024-05-21 | \n", "1.090267 | \n", "0.998137 | \n", "0.934902 | \n", "0.843081 | \n", "
596 | \n", "Dated | \n", "2024-07-25 | \n", "2024-05-21 | \n", "1.089488 | \n", "0.950186 | \n", "0.956440 | \n", "0.826024 | \n", "
597 | \n", "Dated | \n", "2024-07-26 | \n", "2024-05-21 | \n", "0.982749 | \n", "1.021258 | \n", "0.885282 | \n", "0.755292 | \n", "
598 | \n", "Dated | \n", "2024-07-30 | \n", "2024-05-21 | \n", "1.083079 | \n", "0.930227 | \n", "0.883445 | \n", "0.649694 | \n", "
599 | \n", "Dated | \n", "2024-08-02 | \n", "2024-05-21 | \n", "1.033108 | \n", "0.968054 | \n", "0.857022 | \n", "0.605800 | \n", "
600 rows × 7 columns
\n", "