|
import os |
|
import torch |
|
|
|
|
|
from models.LSTM import LSTM |
|
from models.LSTNet import LSTNet |
|
from models.Transformer import Transformer |
|
from models.Autoformer import Autoformer |
|
from models.Informer import Informer |
|
from models.PatchTST import PatchTST |
|
from models.TimesNet import TimesNet |
|
from models.TimesFM import TimesFM |
|
|
|
|
|
from model_kwargs import * |
|
|
|
|
|
|
|
# Configuration whose weights are checked: input window length, forecast
# horizon, and the dataset tag embedded in every weight file name.
lookback, lookahead, heterogeneity = 512, 48, 'HET'


def weights_path(model_name, lookback, lookahead, heterogeneity, root=None):
    """Return the path of the saved state dict for one model configuration.

    The file is expected at
    ``<root>/weights/<model_name>_L_<lookback>_T_<lookahead>_<heterogeneity>.pth``.

    Args:
        model_name: Class name of the model, e.g. ``"LSTM"``.
        lookback: Lookback value encoded in the file name.
        lookahead: Lookahead value encoded in the file name.
        heterogeneity: Tag encoded in the file name, e.g. ``'HET'``.
        root: Directory that contains the ``weights`` folder. Defaults to the
            current working directory, so the script must be launched from
            the directory holding ``weights/`` unless ``root`` is given.

    Returns:
        The joined file path as a string.
    """
    if root is None:
        root = os.getcwd()
    filename = f'{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'
    return os.path.join(root, 'weights', filename)


if __name__ == "__main__":

    # Each model class is paired with the function that builds its constructor
    # kwargs. Keeping the pair in one tuple prevents the two from drifting out
    # of step: parallel lists joined with zip() would silently drop or mispair
    # entries if one list were edited without the other.
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]

    for model_class, kw_fn in model_specs:
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))

        path = weights_path(model_class.__name__, lookback, lookahead, heterogeneity)

        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict needs, and avoids executing
        # arbitrary code embedded in a .pth file.
        state_dict = torch.load(path, map_location='cpu', weights_only=True)

        # load_state_dict is strict by default: it raises on missing or
        # unexpected keys, so the printed result reports a full match.
        result = model.load_state_dict(state_dict)

        print(f"Loading weight for model {model_class.__name__}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")
|
|