Thank you for open-sourcing the model. I tried to use it in wenet but ran into a problem and hope you can help

#1
by zombie1315 - opened

I am trying to reproduce this example with the model you provided: https://github.com/wenet-e2e/wenet/tree/1269a6e5bbec440302e934f243f623baeebf2758/examples/aishell/s0_ssl . Training fails with the following error:
Traceback (most recent call last):
  File "wenet/bin/train.py", line 322, in <module>
    main()
  File "wenet/bin/train.py", line 234, in main
    infos = load_trained_modules(model, args)
  File "/home/wenet_ssl/wenet/utils/checkpoint.py", line 95, in load_trained_modules
    model.load_state_dict(main_state_dict)
  File "/home/miniconda3/envs/wenet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Wav2vec2Model:
    Unexpected key(s) in state_dict: "encoder.embed.linear.weight", "encoder.embed.linear.bias".
    size mismatch for encoder.embed.conv.2.weight: copying a param with shape torch.Size([512, 512, 5, 5]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
What should I change to fix this?
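
For anyone hitting the same error: the message suggests that the encoder's input embedding layer built from the training config differs from the one stored in the checkpoint (extra `encoder.embed.linear.*` keys, and a 5x5 versus 3x3 kernel in `encoder.embed.conv.2`). Below is a minimal sketch of how one might list every such difference before changing the config. The helper name `diff_state_dict_shapes` is hypothetical and not part of wenet or PyTorch, and the sample shapes are only those quoted in the error; the shapes of the two linear tensors are not shown there, so they are left as empty placeholders.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(
    checkpoint: Dict[str, Shape], model: Dict[str, Shape]
) -> Dict[str, List]:
    """Compare parameter names and shapes of a checkpoint against a model.

    Hypothetical helper for illustration; both arguments map a parameter
    name to its shape as a plain tuple.
    """
    # Keys stored in the checkpoint that the model does not define.
    unexpected = sorted(k for k in checkpoint if k not in model)
    # Keys the model defines that the checkpoint does not provide.
    missing = sorted(k for k in model if k not in checkpoint)
    # Keys present on both sides whose shapes disagree.
    mismatched = sorted(
        (k, checkpoint[k], model[k])
        for k in checkpoint
        if k in model and checkpoint[k] != model[k]
    )
    return {"unexpected": unexpected, "missing": missing, "mismatched": mismatched}


# Shapes taken from the error message above. The linear shapes are not
# reported there, so () is used as a placeholder (assumption).
checkpoint_shapes = {
    "encoder.embed.conv.2.weight": (512, 512, 5, 5),
    "encoder.embed.linear.weight": (),
    "encoder.embed.linear.bias": (),
}
model_shapes = {
    "encoder.embed.conv.2.weight": (512, 512, 3, 3),
}

report = diff_state_dict_shapes(checkpoint_shapes, model_shapes)
for kind, entries in report.items():
    print(kind, entries)
```

With PyTorch, the two inputs could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`, once from the loaded checkpoint and once from `model.state_dict()`, assuming the checkpoint file holds a plain state dict. Whichever config option controls the encoder's input layer would then need to match what the checkpoint reports.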

zombie1315 changed discussion status to closed
