loading state dict runtime error
#4, opened by leniad
Hey, I am testing TimesFM for my project, and everything was working fine until today. I didn't change the code; in fact, it is the default code from your example:
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch", torch_compile=True)
RuntimeError: Error(s) in loading state_dict for TimesFM_2p5_200M_torch_module:
Missing key(s) in state_dict: "stacked_xf.0.attn.query.weight", "stacked_xf.0.attn.key.weight", "stacked_xf.0.attn.value.weight", "stacked_xf.1.attn.query.weight", "stacked_xf.1.attn.key.weight", "stacked_xf.1.attn.value.weight", "stacked_xf.2.attn.query.weight", "stacked_xf.2.attn.key.weight", "stacked_xf.2.attn.value.weight", "stacked_xf.3.attn.query.weight", "stacked_xf.3.attn.key.weight", "stacked_xf.3.attn.value.weight", "stacked_xf.4.attn.query.weight", "stacked_xf.4.attn.key.weight", "stacked_xf.4.attn.value.weight", "stacked_xf.5.attn.query.weight", "stacked_xf.5.attn.key.weight", "stacked_xf.5.attn.value.weight", "stacked_xf.6.attn.query.weight", "stacked_xf.6.attn.key.weight", "stacked_xf.6.attn.value.weight", "stacked_xf.7.attn.query.weight", "stacked_xf.7.attn.key.weight", "stacked_xf.7.attn.value.weight", "stacked_xf.8.attn.query.weight", "stacked_xf.8.attn.key.weight", "stacked_xf.8.attn.value.weight", "stacked_xf.9.attn.query.weight", "stacked_xf.9.attn.key.weight", "stacked_xf.9.attn.value.weight", "stacked_xf.10.attn.query.weight", "stacked_xf.10.attn.key.weight", "stacked_xf.10.attn.value.weight", "stacked_xf.11.attn.query.weight", "stacked_xf.11.attn.key.weight", "stacked_xf.11.attn.value.weight", "stacked_xf.12.attn.query.weight", "stacked_xf.12.attn.key.weight", "stacked_xf.12.attn.value.weight", "stacked_xf.13.attn.query.weight", "stacked_xf.13.attn.key.weight", "stacked_xf.13.attn.value.weight", "stacked_xf.14.attn.query.weight", "stacked_xf.14.attn.key.weight", "stacked_xf.14.attn.value.weight", "stacked_xf.15.attn.query.weight", "stacked_xf.15.attn.key.weight", "stacked_xf.15.attn.value.weight", "stacked_xf.16.attn.query.weight", "stacked_xf.16.attn.key.weight", "stacked_xf.16.attn.value.weight", "stacked_xf.17.attn.query.weight", "stacked_xf.17.attn.key.weight", "stacked_xf.17.attn.value.weight", "stacked_xf.18.attn.query.weight", "stacked_xf.18.attn.key.weight", "stacked_xf.18.attn.value.weight", "stacked_xf.19.attn.query.weight", 
"stacked_xf.19.attn.key.weight", "stacked_xf.19.attn.value.weight".
Unexpected key(s) in state_dict: "stacked_xf.0.attn.qkv_proj.weight", "stacked_xf.1.attn.qkv_proj.weight", "stacked_xf.2.attn.qkv_proj.weight", "stacked_xf.3.attn.qkv_proj.weight", "stacked_xf.4.attn.qkv_proj.weight", "stacked_xf.5.attn.qkv_proj.weight", "stacked_xf.6.attn.qkv_proj.weight", "stacked_xf.7.attn.qkv_proj.weight", "stacked_xf.8.attn.qkv_proj.weight", "stacked_xf.9.attn.qkv_proj.weight", "stacked_xf.10.attn.qkv_proj.weight", "stacked_xf.11.attn.qkv_proj.weight", "stacked_xf.12.attn.qkv_proj.weight", "stacked_xf.13.attn.qkv_proj.weight", "stacked_xf.14.attn.qkv_proj.weight", "stacked_xf.15.attn.qkv_proj.weight", "stacked_xf.16.attn.qkv_proj.weight", "stacked_xf.17.attn.qkv_proj.weight", "stacked_xf.18.attn.qkv_proj.weight", "stacked_xf.19.attn.qkv_proj.weight".
This version is also not working:
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
I reinstalled the package and it's working correctly now.
Hi @leniad, we did a speed optimization that fuses the QKV projection into one matrix: https://github.com/google-research/timesfm/pull/316
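To illustrate what fusing the QKV projection means, here is a minimal pure-Python sketch on toy numbers. It assumes the fused matrix stacks the query, key and value weight rows in that order; the actual layout in the PR may differ, and `matvec`, `w_query`, `w_key`, `w_value` and `w_qkv` are hypothetical names, not timesfm identifiers.

```python
def matvec(weight, x):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


# Toy 2x2 projection weights (made-up values, not TimesFM weights).
w_query = [[1, 2], [3, 4]]
w_key = [[0, 1], [1, 0]]
w_value = [[2, 0], [0, 2]]
x = [5, 7]

# Split layout: three separate matrix-vector products.
q, k, v = matvec(w_query, x), matvec(w_key, x), matvec(w_value, x)

# Fused layout: stack the rows once, do one product, then slice the result.
# Assumption: rows are ordered query, key, value.
w_qkv = w_query + w_key + w_value
fused = matvec(w_qkv, x)
d = len(w_query)
q_fused, k_fused, v_fused = fused[:d], fused[d:2 * d], fused[2 * d:]

# Both layouts compute the same projections.
assert (q_fused, k_fused, v_fused) == (q, k, v)
```

The results are identical, but the parameters are stored under different names (one `qkv_proj.weight` instead of separate `query`, `key` and `value` weights), so a checkpoint saved in one layout cannot be loaded by code that expects the other.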
So we changed both the model and the code. You might need to pull from main again and reinstall the package. Please let me know if that works.
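A quick way to see which side is stale is to compare the key names on each side. The sketch below works on plain lists of key names copied from the error message above; `attention_layout` is a hypothetical helper written for this thread, not part of timesfm.

```python
def attention_layout(keys):
    """Classify state_dict key names as 'fused', 'split', 'mixed' or 'unknown'."""
    fused = any(".attn.qkv_proj." in k for k in keys)
    split = any(
        f".attn.{name}." in k for k in keys for name in ("query", "key", "value")
    )
    if fused and split:
        return "mixed"
    if fused:
        return "fused"
    if split:
        return "split"
    return "unknown"


# Key names copied from the error message (first layer only).
missing = [
    "stacked_xf.0.attn.query.weight",
    "stacked_xf.0.attn.key.weight",
    "stacked_xf.0.attn.value.weight",
]
unexpected = ["stacked_xf.0.attn.qkv_proj.weight"]

# In a PyTorch load error, "missing" keys are what the installed code expects,
# and "unexpected" keys are what the downloaded checkpoint contains.
print("installed code expects:", attention_layout(missing))
print("checkpoint contains:", attention_layout(unexpected))
```

Here the installed code expects the split layout while the checkpoint is fused, which suggests the installed package is older than the checkpoint and is consistent with reinstalling from main. With a real model you could pass `model.state_dict().keys()` and the keys of the loaded checkpoint instead of the hand-copied lists.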
rajatsen91 changed discussion status to closed