# Last commit: a9073bb by Shourya Bose ("add timefm weights")
import os
import torch
# import models
from models.LSTM import LSTM
from models.LSTNet import LSTNet
from models.Transformer import Transformer
from models.Autoformer import Autoformer
from models.Informer import Informer
from models.PatchTST import PatchTST
from models.TimesNet import TimesNet
from models.TimesFM import TimesFM
# import keyword args
from model_kwargs import *
# Window configuration that selects which pretrained weight file is loaded.
# The lookback is fixed at 512; the lookahead may be one of 4, 48 or 96.
lookback = 512
lookahead = 48
# Heterogeneity setting: either 'HET' or 'HOM'.
heterogeneity = 'HET'
if __name__ == "__main__":
    # Pair each model class with the keyword-argument factory that builds its
    # constructor arguments. Explicit pairs (instead of two parallel lists
    # joined with zip) cannot silently drift out of alignment or be truncated
    # when one list is edited without the other.
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]
    # Weight files are resolved relative to the current working directory, so
    # the script must be launched from the directory that contains 'weights/'.
    weights_dir = os.path.join(os.getcwd(), 'weights')
    # Loop over models and their keyword-argument factories.
    for model_class, kw_fn in model_specs:
        model_name = model_class.__name__
        # Build the model with hyperparameters matching this window configuration.
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))
        weight_file = os.path.join(
            weights_dir,
            f'{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth',
        )
        # NOTE(review): torch.load unpickles arbitrary Python objects. If these
        # files can come from an untrusted source, consider passing
        # weights_only=True (supported since torch 1.13) - confirm torch version.
        state_dict = torch.load(weight_file, map_location='cpu')
        # load_state_dict reports missing/unexpected keys in its return value;
        # print it so the outcome of each load is visible.
        result = model.load_state_dict(state_dict)
        print(f"Loading weight for model {model_name}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")