File size: 1,185 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import unittest
from onmt.translate import GeneratorLM
import torch


class TestGeneratorLM(unittest.TestCase):
    """Tests for ``GeneratorLM.split_src_to_prevent_padding``.

    Inputs are built batch-first, as ``(batch, seq_len, 1)``, with one
    length per batch element in ``src_len``.
    """

    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(  # noqa: E501
        self,
    ):
        """When every length equals the sequence length, nothing is split."""
        batch_size, seq_len = 6, 5
        src = torch.randint(0, 10, (batch_size, seq_len, 1))
        # One length per batch element. The batch holds `batch_size` rows,
        # so `src_len` needs the same number of entries.
        src_len = seq_len * torch.ones(batch_size, dtype=torch.int)
        # Keep untouched copies so the outputs can be compared to the inputs.
        original_src = src.clone()
        original_src_len = src_len.clone()
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertIsNone(target_prefix)
        # With no shorter sequence, the source and lengths come back intact.
        self.assertTrue(src.equal(original_src))
        self.assertTrue(src_len.equal(original_src_len))

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(  # noqa: E501
        self,
    ):
        """A shorter sequence truncates src; the remainder becomes the prefix."""
        batch_size = 6
        default_length = 5
        src = torch.randint(0, 10, (batch_size, default_length, 1))
        src_len = default_length * torch.ones(batch_size, dtype=torch.int)
        new_length = 4
        src_len[1] = new_length
        # Clone before the call, since the returned tensors may be views or
        # in-place updates of the inputs.
        original_src = src.clone()
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertTupleEqual(src.shape, (batch_size, new_length, 1))
        self.assertTupleEqual(
            target_prefix.shape, (batch_size, default_length - new_length, 1)
        )
        self.assertTrue(
            src_len.equal(new_length * torch.ones(batch_size, dtype=torch.int))
        )
        # The tokens must be split at the shortest length, not just reshaped:
        # the leading part stays in src and the tail moves to target_prefix.
        self.assertTrue(src.equal(original_src[:, :new_length, :]))
        self.assertTrue(target_prefix.equal(original_src[:, new_length:, :]))