loading state dict runtime error

#4
by leniad - opened

Hey, I am testing TimesFM for my project, and everything was working fine until today. I didn't change the code; in fact, it is the default code from your example:

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch", torch_compile=True)

RuntimeError: Error(s) in loading state_dict for TimesFM_2p5_200M_torch_module:
    Missing key(s) in state_dict: "stacked_xf.0.attn.query.weight", "stacked_xf.0.attn.key.weight", "stacked_xf.0.attn.value.weight", "stacked_xf.1.attn.query.weight", "stacked_xf.1.attn.key.weight", "stacked_xf.1.attn.value.weight", "stacked_xf.2.attn.query.weight", "stacked_xf.2.attn.key.weight", "stacked_xf.2.attn.value.weight", "stacked_xf.3.attn.query.weight", "stacked_xf.3.attn.key.weight", "stacked_xf.3.attn.value.weight", "stacked_xf.4.attn.query.weight", "stacked_xf.4.attn.key.weight", "stacked_xf.4.attn.value.weight", "stacked_xf.5.attn.query.weight", "stacked_xf.5.attn.key.weight", "stacked_xf.5.attn.value.weight", "stacked_xf.6.attn.query.weight", "stacked_xf.6.attn.key.weight", "stacked_xf.6.attn.value.weight", "stacked_xf.7.attn.query.weight", "stacked_xf.7.attn.key.weight", "stacked_xf.7.attn.value.weight", "stacked_xf.8.attn.query.weight", "stacked_xf.8.attn.key.weight", "stacked_xf.8.attn.value.weight", "stacked_xf.9.attn.query.weight", "stacked_xf.9.attn.key.weight", "stacked_xf.9.attn.value.weight", "stacked_xf.10.attn.query.weight", "stacked_xf.10.attn.key.weight", "stacked_xf.10.attn.value.weight", "stacked_xf.11.attn.query.weight", "stacked_xf.11.attn.key.weight", "stacked_xf.11.attn.value.weight", "stacked_xf.12.attn.query.weight", "stacked_xf.12.attn.key.weight", "stacked_xf.12.attn.value.weight", "stacked_xf.13.attn.query.weight", "stacked_xf.13.attn.key.weight", "stacked_xf.13.attn.value.weight", "stacked_xf.14.attn.query.weight", "stacked_xf.14.attn.key.weight", "stacked_xf.14.attn.value.weight", "stacked_xf.15.attn.query.weight", "stacked_xf.15.attn.key.weight", "stacked_xf.15.attn.value.weight", "stacked_xf.16.attn.query.weight", "stacked_xf.16.attn.key.weight", "stacked_xf.16.attn.value.weight", "stacked_xf.17.attn.query.weight", "stacked_xf.17.attn.key.weight", "stacked_xf.17.attn.value.weight", "stacked_xf.18.attn.query.weight", "stacked_xf.18.attn.key.weight", "stacked_xf.18.attn.value.weight", "stacked_xf.19.attn.query.weight", "stacked_xf.19.attn.key.weight", "stacked_xf.19.attn.value.weight".
    Unexpected key(s) in state_dict: "stacked_xf.0.attn.qkv_proj.weight", "stacked_xf.1.attn.qkv_proj.weight", "stacked_xf.2.attn.qkv_proj.weight", "stacked_xf.3.attn.qkv_proj.weight", "stacked_xf.4.attn.qkv_proj.weight", "stacked_xf.5.attn.qkv_proj.weight", "stacked_xf.6.attn.qkv_proj.weight", "stacked_xf.7.attn.qkv_proj.weight", "stacked_xf.8.attn.qkv_proj.weight", "stacked_xf.9.attn.qkv_proj.weight", "stacked_xf.10.attn.qkv_proj.weight", "stacked_xf.11.attn.qkv_proj.weight", "stacked_xf.12.attn.qkv_proj.weight", "stacked_xf.13.attn.qkv_proj.weight", "stacked_xf.14.attn.qkv_proj.weight", "stacked_xf.15.attn.qkv_proj.weight", "stacked_xf.16.attn.qkv_proj.weight", "stacked_xf.17.attn.qkv_proj.weight", "stacked_xf.18.attn.qkv_proj.weight", "stacked_xf.19.attn.qkv_proj.weight". 

This version does not work either:

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")

I reinstalled the package and it's working correctly now.

Google org

Hi @leniad, we made a speed optimization that fuses the query, key, and value (QKV) projections into a single matrix: https://github.com/google-research/timesfm/pull/316
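To illustrate what fusing the QKV projection means, here is a minimal PyTorch sketch. It is not the actual TimesFM code: three separate bias-free `nn.Linear` layers named after the old checkpoint keys (`query`, `key`, `value`) are folded into one `qkv_proj` layer by stacking their weight matrices. The toy width `d_model`, the helper `fuse_qkv`, and the [query; key; value] row ordering are assumptions for illustration; the linked PR may order or shape the fused weight differently.

```python
import torch
import torch.nn as nn


def fuse_qkv(query: nn.Linear, key: nn.Linear, value: nn.Linear) -> nn.Linear:
    """Hypothetical helper: fold three bias-free projections into one layer.

    Assumes the fused weight stacks rows as [query; key; value]; the real
    TimesFM checkpoint layout is not confirmed here.
    """
    fused = nn.Linear(query.in_features, 3 * query.out_features, bias=False)
    with torch.no_grad():
        # nn.Linear stores weight as (out_features, in_features), so stack on dim 0.
        fused.weight.copy_(torch.cat([query.weight, key.weight, value.weight], dim=0))
    return fused


torch.manual_seed(0)
d_model = 8  # toy width, not TimesFM's real hidden size

# Old layout: one matmul per projection (mirrors "attn.query/key/value.weight").
query = nn.Linear(d_model, d_model, bias=False)
key = nn.Linear(d_model, d_model, bias=False)
value = nn.Linear(d_model, d_model, bias=False)

# New layout: a single matmul (mirrors "attn.qkv_proj.weight").
qkv_proj = fuse_qkv(query, key, value)

x = torch.randn(4, d_model)
q, k, v = qkv_proj(x).chunk(3, dim=-1)  # split the 3 * d_model outputs back apart
print(torch.allclose(q, query(x), atol=1e-6),
      torch.allclose(k, key(x), atol=1e-6),
      torch.allclose(v, value(x), atol=1e-6))
```

One larger matmul replaces three smaller ones, which is typically where the speedup comes from. The catch is that the parameter names change, so a checkpoint saved in one layout no longer loads into code written for the other.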

So we changed both the model checkpoint and the code. You might need to pull from main again and reinstall the package. Please let me know if that works.
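Both lists in the traceback follow from comparing parameter names: the previously installed code builds separate `query`/`key`/`value` weights, while the updated checkpoint ships only `qkv_proj`. The pure-Python sketch below reproduces that comparison with a hypothetical helper, `diff_state_dict_keys`. It mirrors how PyTorch's `load_state_dict` classifies keys as missing (expected by the module but absent from the checkpoint) or unexpected (present in the checkpoint but unknown to the module), although PyTorch reports them in module order rather than sorted. The layer count of 20 is read off the error message above.

```python
def diff_state_dict_keys(expected, provided):
    """Hypothetical helper: classify key mismatches the way load_state_dict reports them."""
    expected, provided = set(expected), set(provided)
    missing = sorted(expected - provided)     # module wants these, checkpoint lacks them
    unexpected = sorted(provided - expected)  # checkpoint has these, module does not
    return missing, unexpected


NUM_LAYERS = 20  # stacked_xf.0 ... stacked_xf.19, as listed in the error

# Parameter names the older installed code expects (separate projections).
old_code_keys = [
    f"stacked_xf.{i}.attn.{name}.weight"
    for i in range(NUM_LAYERS)
    for name in ("query", "key", "value")
]
# Parameter names in the updated checkpoint (fused projection).
new_checkpoint_keys = [f"stacked_xf.{i}.attn.qkv_proj.weight" for i in range(NUM_LAYERS)]

missing, unexpected = diff_state_dict_keys(old_code_keys, new_checkpoint_keys)
print(len(missing), len(unexpected))  # prints: 60 20
```

Because the weights were renamed and restacked rather than retrained, the straightforward fix is to bring the code and checkpoint back into agreement, which is why reinstalling the package from main resolves the error.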

rajatsen91 changed discussion status to closed
