import unittest

from onmt.translate import GeneratorLM
import torch


class TestGeneratorLM(unittest.TestCase):
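    """Tests for GeneratorLM.split_src_to_prevent_padding.

    As exercised by the assertions below, src is batch-first with shape
    (batch, seq_len, feats) and src_len holds one length per batch element:
    with uniform lengths nothing is split, otherwise src is truncated to the
    shortest length and the remainder is returned as a target prefix.
    """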
    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(
        self,
    ):
        src = torch.randint(0, 10, (6, 5, 1))
        # one length entry per batch element (the batch size here is 6)
        src_len = 5 * torch.ones(6, dtype=torch.int)
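        # With uniform lengths, no target prefix should be split off.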
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertIsNone(target_prefix)

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(
        self,
    ):
        default_length = 5
        src = torch.randint(0, 10, (6, default_length, 1))
        src_len = default_length * torch.ones(6, dtype=torch.int)
        # shorten one sequence so the batch lengths are no longer uniform
        new_length = 4
        src_len[1] = new_length
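        # The whole batch should be truncated to the shortest length, with
        # the removed final step returned as a one-step target prefix.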
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertTupleEqual(src.shape, (6, new_length, 1))
        self.assertTupleEqual(target_prefix.shape, (6, 1, 1))
        self.assertTrue(src_len.equal(new_length * torch.ones(6, dtype=torch.int)))
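

# Conventional guard so the file can also be run directly with
# "python <this file>" rather than only through a test runner.
if __name__ == "__main__":
    unittest.main()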