Change `wte` to use shared embedding
#43
opened by bcui19
daking changed pull request status to merged
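For context, "shared embedding" here means tying the output projection to the input embedding so one weight matrix serves both roles. A minimal sketch of the idea, with illustrative names rather than the actual MPT code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedEmbedding(nn.Embedding):
    """Embedding whose weight is reused as the output (unembedding) projection.

    Illustrative sketch only: the extra `unembed` flag switches between the
    usual token-id -> vector lookup and a hidden-state -> logits projection
    tied to the same weight matrix.
    """

    def forward(self, input: torch.Tensor, unembed: bool = False) -> torch.Tensor:
        if unembed:
            # Project hidden states back to vocabulary logits with the tied weight.
            return F.linear(input, self.weight)
        # Default path: ordinary embedding lookup.
        return super().forward(input)
```

With a layer like this, the model's forward can call `wte(hidden_states, True)` to produce logits without a separate lm_head, which is the call shown in the traceback below.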
This change leads to the following error with torch==0.2.1:
"/home/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-instruct/_____/modeling_mpt.py", line 271, in forward
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
File "/home/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
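The message is consistent with `self.transformer.wte` still being a plain `nn.Embedding`, whose `forward` accepts only the input tensor, while the updated `modeling_mpt.py` passes the unembed flag as a second positional argument. A minimal repro of the same TypeError, outside the MPT code:

```python
import torch
import torch.nn as nn

# A plain nn.Embedding: forward(self, input) takes a single positional argument.
wte = nn.Embedding(8, 4)
hidden = torch.randn(2, 3, 4)

try:
    # Passing an extra positional flag, as the updated modeling code does,
    # fails before forward() ever runs.
    wte(hidden, True)
except TypeError as e:
    print(e)  # forward() takes 2 positional arguments but 3 were given
```

If that is what is happening here, the cached modeling code and the embedding class are likely out of sync, e.g. an older snapshot of one file combined with a newer snapshot of the other.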