"""Load pretrained weights into every forecasting model and print the outcome.

For each model class, an instance is built from its keyword-argument builder
in ``model_kwargs`` and the checkpoint
``weights/<ModelName>_L_<lookback>_T_<lookahead>_<heterogeneity>.pth``
(resolved relative to the current working directory) is loaded into it on CPU.
"""
import os

import torch

# import models
from models.LSTM import LSTM
from models.LSTNet import LSTNet
from models.Transformer import Transformer
from models.Autoformer import Autoformer
from models.Informer import Informer
from models.PatchTST import PatchTST
from models.TimesNet import TimesNet
from models.TimesFM import TimesFM

# import keyword-argument builders explicitly (instead of a wildcard import)
from model_kwargs import (
    autoformer_kwargs,
    informer_kwargs,
    lstm_kwargs,
    lstnet_kwargs,
    patchtst_kwargs,
    timesfm_kwargs,
    timesnet_kwargs,
    transformer_kwargs,
)

# Set lookback and lookahead: lookback is fixed to 512, while lookahead can be
# one of 4, 48 or 96. Heterogeneity can be 'HET' or 'HOM'.
lookback, lookahead, heterogeneity = 512, 48, 'HET'

# (model class, keyword-argument builder) pairs. Keeping each pair together
# avoids the silent misalignment/truncation that zipping two parallel lists
# allows when one list is edited without the other.
MODEL_SPECS = [
    (LSTM, lstm_kwargs),
    (LSTNet, lstnet_kwargs),
    (Transformer, transformer_kwargs),
    (Autoformer, autoformer_kwargs),
    (Informer, informer_kwargs),
    (PatchTST, patchtst_kwargs),
    (TimesNet, timesnet_kwargs),
    (TimesFM, timesfm_kwargs),
]


def _weight_path(model_name, lookback, lookahead, heterogeneity):
    """Return the checkpoint path for one model/configuration under ./weights."""
    file_name = f'{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'
    return os.path.join(os.getcwd(), 'weights', file_name)


def main(lookback=lookback, lookahead=lookahead, heterogeneity=heterogeneity):
    """Build every model in ``MODEL_SPECS`` and load its matching weights.

    Args:
        lookback: input window length passed to each kwargs builder and used
            in the checkpoint file name (defaults to the module setting).
        lookahead: forecast horizon, used the same way as ``lookback``.
        heterogeneity: checkpoint variant tag used in the file name.

    Raises:
        Whatever ``torch.load`` raises for a missing/unreadable checkpoint, and
        whatever ``load_state_dict`` raises on a key/shape mismatch (it is
        called with its default ``strict=True``).
    """
    for model_class, kw_fn in MODEL_SPECS:
        model_name = model_class.__name__
        # instantiate the model with the configuration-specific kwargs
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source. map_location='cpu' lets GPU-saved weights load anywhere.
        state_dict = torch.load(
            _weight_path(model_name, lookback, lookahead, heterogeneity),
            map_location='cpu',
        )
        result = model.load_state_dict(state_dict)
        # print the outcome (the missing/unexpected-keys report from load_state_dict)
        print(f"Loading weight for model {model_name}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")


if __name__ == "__main__":
    main()