# Source: ReactSeq repository, onmt/tests/test_translator.py
# (uploaded by Oopstom in "Upload 313 files", commit c668e80)
import unittest
from onmt.translate import GeneratorLM
import torch
class TestGeneratorLM(unittest.TestCase):
    """Unit tests for ``GeneratorLM.split_src_to_prevent_padding``.

    The fixtures build ``src`` with dim 0 as the batch axis and dim 1 as the
    sequence axis (see the shape assertions below), and ``src_len`` with one
    length per batch entry.
    """

    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(  # noqa: E501
        self,
    ):
        """When every length equals the padded width, nothing is split off."""
        batch_size = 6
        seq_len = 5
        src = torch.randint(0, 10, (batch_size, seq_len, 1))
        # One length per batch entry; the fixture previously supplied only
        # 5 lengths for a batch of 6.
        src_len = seq_len * torch.ones(batch_size, dtype=torch.int)
        (
            split_src,
            split_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertIsNone(target_prefix)
        # With no padding to remove, source and lengths must be unchanged.
        self.assertTupleEqual(
            tuple(split_src.shape), (batch_size, seq_len, 1)
        )
        self.assertTrue(split_src.equal(src))
        self.assertTrue(
            split_len.equal(seq_len * torch.ones(batch_size, dtype=torch.int))
        )

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(  # noqa: E501
        self,
    ):
        """A shorter entry truncates all rows; the overflow is the prefix."""
        batch_size = 6
        default_length = 5
        new_length = 4
        src = torch.randint(0, 10, (batch_size, default_length, 1))
        src_len = default_length * torch.ones(batch_size, dtype=torch.int)
        src_len[1] = new_length  # make one entry shorter than the others
        (
            split_src,
            split_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertTupleEqual(
            tuple(split_src.shape), (batch_size, new_length, 1)
        )
        # The prefix holds exactly the columns cut off from the source.
        self.assertTupleEqual(
            tuple(target_prefix.shape),
            (batch_size, default_length - new_length, 1),
        )
        self.assertTrue(
            split_len.equal(new_length * torch.ones(batch_size, dtype=torch.int))
        )
        # Re-joining the truncated source with the prefix along the sequence
        # axis must reproduce the original tensor, i.e. no tokens are lost.
        self.assertTrue(
            torch.cat([split_src, target_prefix], dim=1).equal(src)
        )