Spaces:
Running
Running
File size: 938 Bytes
3a010aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import os
def prelude():
    """Prepare the process environment and patch fairseq for TorchScript.

    Two things happen, in order:

    1. ``PYTORCH_JIT`` is set to ``"0v"`` in ``os.environ``.
    2. If the installed fairseq ``model_utils.py`` exists at the hard-coded
       site-packages path, every line containing the un-annotated signature
       of ``expand_2d_or_3d_tensor`` is replaced by a fully annotated one
       (``x: Tensor`` plus a ``-> Tensor`` return annotation).

    The file is rewritten only when at least one line actually needs
    patching, so calling this repeatedly (or on an already-patched install)
    leaves the file untouched.

    Raises:
        OSError: If the file exists but cannot be read, or if it needs
            patching but cannot be written.
    """
    # NOTE(review): "0v" looks like PyTorch's verbose "JIT disabled" setting;
    # confirm against the torch.jit documentation before changing it.
    os.environ["PYTORCH_JIT"] = "0v"

    # Patch for jit script: the target file is the fairseq copy installed in
    # the Python 3.10 site-packages directory.
    FAIRSEQ_CODE = (
        "/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py"
    )
    unannotated_signature = (
        "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
    )
    annotated_line = (
        "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int) -> Tensor:\n"
    )

    if not os.path.exists(FAIRSEQ_CODE):
        return

    # Python source files are UTF-8 by default, so do not depend on the locale.
    with open(FAIRSEQ_CODE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    patched = [
        annotated_line if unannotated_signature in line else line
        for line in lines
    ]

    # Nothing matched (e.g. already patched): avoid truncating and rewriting
    # an installed module for no reason.
    if patched == lines:
        return

    with open(FAIRSEQ_CODE, "w", encoding="utf-8") as f:
        f.writelines(patched)
|