"""Unit tests for word-start detection and subword-to-word grouping helpers."""
import unittest

from onmt.transforms.bart import word_start_finder
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
from onmt.constants import SubwordMarker


class TestWordStartFinder(unittest.TestCase):
    """Check the per-token word-start flags returned by ``word_start_finder``."""

    def test_word_start_naive(self):
        """With ``ignore_subword=True`` every token is flagged as a word start."""
        word_start_finder_fn = word_start_finder(ignore_subword=True)
        data_in = [
            "however", ",", "according", "to", "the", "logs", ",",
            "she", "is", "hard", "-", "working", ".",
        ]
        true_out = [
            True, True, True, True, True, True, True,
            True, True, True, True, True, True,
        ]
        out = word_start_finder_fn(data_in)
        self.assertEqual(out, true_out)

    def test_word_start_joiner(self):
        """With ``is_joiner=True`` joiner-attached tokens are not word starts."""
        word_start_finder_fn = word_start_finder(is_joiner=True)
        data_in = [
            "however", "■,", "according", "to", "the", "logs", "■,",
            "she", "is", "hard", "■-■", "working", "■.",
        ]
        true_out = [
            True, False, True, True, True, True, False,
            True, True, True, False, False, False,
        ]
        out = word_start_finder_fn(data_in)
        self.assertEqual(out, true_out)

    def test_word_start_spacer(self):
        """Default finder: same flags with and without a leading spacer."""
        word_start_finder_fn = word_start_finder()
        data_in = [
            "▁however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        true_out = [
            True, False, True, True, True, True, False,
            True, True, True, False, False, False,
        ]
        out = word_start_finder_fn(data_in)
        self.assertEqual(out, true_out)
        # no dummy prefix
        no_dummy = [
            "however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        no_dummy_out = word_start_finder_fn(no_dummy)
        self.assertEqual(no_dummy_out, true_out)


class TestSubwordGroup(unittest.TestCase):
    """Check the word index assigned to each subword by the mapping helpers."""

    def test_subword_group_joiner(self):
        """Joiner-attached tokens share the index of the word they attach to."""
        data_in = [
            "however", "■,", "according", "to", "the", "logs", "■,",
            "she", "is", "hard", "■-■", "working", "■.",
        ]
        true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_case_markup(self):
        """Case-markup placeholders are grouped with a neighbouring word."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "however", "■,", "according", "to",
            "the", "logs", "■,", "⦅mrk_begin_case_region_U⦆", "she", "is",
            "hard", "■-■", "working", "⦅mrk_end_case_region_U⦆", "■.",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_case_markup_advanced(self):
        """Case markup occurring in the middle of a joined word."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "dummy", "text",
            "⦅mrk_case_modifier_C⦆", "1■", "h■", "k",
            "⦅mrk_begin_case_region_U⦆", "th■", "⦅mrk_end_case_region_U⦆",
            "n", "more", "dummy", "text",
        ]
        true_out = [0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 5, 6]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_prior_tokenization(self):
        """Grouping when ``original_subwords`` holds a coarser tokenization."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "how■", "ever", "■,", "according", "to",
            "the", "logs", "■,", "⦅mrk_begin_case_region_U⦆", "she", "is",
            "hard", "■-■", "working", "⦅mrk_end_case_region_U⦆", "■.",
        ]
        original_data_in = [
            "However", "■,", "according", "to", "the", "logs", "■,",
            "SHE", "IS", "HARD-WORKING", "■.",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 9, 10]
        out = subword_map_by_joiner(
            data_in, marker=SubwordMarker.JOINER, original_subwords=original_data_in
        )
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_prior_tokenization_harder(self):
        """When ``original_subwords`` equals the input, each token is its own group."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "how■", "ever", "■,", "according", "to",
            "the", "logs", "■,", "⦅mrk_begin_case_region_U⦆", "she", "is",
            "hard", "■-■", "working", "⦅mrk_end_case_region_U⦆", "■.",
        ]
        original_data_in = [
            "⦅mrk_case_modifier_C⦆", "how■", "ever", "■,", "according", "to",
            "the", "logs", "■,", "⦅mrk_begin_case_region_U⦆", "she", "is",
            "hard", "■-■", "working", "⦅mrk_end_case_region_U⦆", "■.",
        ]
        true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
        out = subword_map_by_joiner(
            data_in, marker=SubwordMarker.JOINER, original_subwords=original_data_in
        )
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_new_joiner(self):
        """Joiners emitted as standalone ``■`` tokens glue their neighbours."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "however", "■", ",", "according", "to",
            "the", "logs", "■", ",", "⦅mrk_begin_case_region_U⦆", "she", "is",
            "hard", "■", "-", "■", "working", "⦅mrk_end_case_region_U⦆",
            "■", ".",
        ]
        true_out = [
            0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7,
        ]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_naive(self):
        """Without any joiner marker every token forms its own group."""
        data_in = [
            "however", ",", "according", "to", "the", "logs", ",",
            "she", "is", "hard", "-", "working", ".",
        ]
        true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_spacer(self):
        """Spacer grouping is the same with and without a leading spacer."""
        # The first fixture carries the dummy prefix ("▁however") so that it
        # actually differs from the ``no_dummy`` fixture below; both must map
        # to the same word indices.
        data_in = [
            "▁however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)
        # no dummy prefix
        no_dummy = [
            "however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        no_dummy_out = subword_map_by_spacer(no_dummy, marker=SubwordMarker.SPACER)
        self.assertEqual(no_dummy_out, true_out)

    def test_subword_group_spacer_with_case_markup(self):
        """Case-markup placeholders are grouped with a neighbouring word."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "▁however", ",", "▁according", "▁to",
            "▁the", "▁logs", ",", "▁⦅mrk_begin_case_region_U⦆", "▁she",
            "▁is", "▁hard", "-", "working", ".", "▁⦅mrk_end_case_region_U⦆",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)

    def test_subword_group_spacer_with_spacer_new(self):
        """Spacers emitted as standalone ``▁`` tokens join the following word."""
        data_in = [
            "⦅mrk_case_modifier_C⦆", "▁", "however", ",", "▁", "according",
            "▁", "to", "▁", "the", "▁", "logs", ",", "▁",
            "⦅mrk_begin_case_region_U⦆", "▁", "she", "▁", "is", "▁", "hard",
            "-", "working", ".", "▁", "⦅mrk_end_case_region_U⦆",
        ]
        true_out = [
            0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4,
            5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7,
        ]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)


if __name__ == "__main__":
    unittest.main()