|
"""Tests for the implemented transforms."""
|
import unittest |
|
|
|
import copy |
|
import yaml |
|
import math |
|
from argparse import Namespace |
|
from onmt.transforms import ( |
|
get_transforms_cls, |
|
get_specials, |
|
make_transforms, |
|
TransformPipe, |
|
) |
|
from onmt.transforms.bart import BARTNoising |
|
|
|
|
|
class TestTransform(unittest.TestCase):
    """Registry, specials and pipeline behaviour shared by all transforms."""

    def test_transform_register(self):
        """Every builtin transform name resolves through the registry."""
        builtin_transform = [
            "filtertoolong",
            "prefix",
            "sentencepiece",
            "bpe",
            "onmt_tokenize",
            "bart",
            "switchout",
            "tokendrop",
            "tokenmask",
            "insert_mask_before_placeholder",
            # Also looked up by TestFeaturesTransform in this module.
            "inferfeats",
        ]
        get_transforms_cls(builtin_transform)

    def test_vocab_required_transform(self):
        """Vocab-dependent transforms are not built/warmed up without vocabs."""
        transforms_cls = get_transforms_cls(["bart", "switchout"])
        opt = Namespace(seed=-1, switchout_temperature=1.0)
        # Without vocabs, make_transforms builds none of these transforms.
        transforms = make_transforms(opt, transforms_cls, vocabs=None)
        self.assertEqual(len(transforms), 0)
        with self.assertRaises(ValueError):
            transforms_cls["switchout"](opt).warm_up(vocabs=None)
            # NOTE(review): unreachable -- the switchout call above raises
            # first, so bart's warm_up is never exercised here. It is not
            # moved into its own assertRaises(ValueError) because the
            # exception bart raises without vocabs is not visible from this
            # file; confirm before tightening this check.
            transforms_cls["bart"](opt).warm_up(vocabs=None)

    def test_transform_specials(self):
        """Corpus-level prefixes are reported as src/tgt special tokens."""
        transforms_cls = get_transforms_cls(["prefix"])
        corpora = yaml.safe_load(
            """
        trainset:
            path_src: data/src-train.txt
            path_tgt: data/tgt-train.txt
            transforms: ["prefix"]
            weight: 1
            src_prefix: "⦅_pf_src⦆"
            tgt_prefix: "⦅_pf_tgt⦆"
        """
        )
        opt = Namespace(data=corpora)
        specials = get_specials(opt, transforms_cls)
        specials_expected = {"src": ["⦅_pf_src⦆"], "tgt": ["⦅_pf_tgt⦆"]}
        self.assertEqual(specials, specials_expected)

    def test_transform_pipe(self):
        """A prefix + filtertoolong pipe filters the example and keeps stats."""
        # 1. Prefix transform configured for the "trainset" corpus.
        prefix_cls = get_transforms_cls(["prefix"])["prefix"]
        corpora = yaml.safe_load(
            """
        trainset:
            path_src: data/src-train.txt
            path_tgt: data/tgt-train.txt
            transforms: [prefix, filtertoolong]
            weight: 1
            src_prefix: "⦅_pf_src⦆"
            tgt_prefix: "⦅_pf_tgt⦆"
        """
        )
        opt = Namespace(data=corpora, seed=-1)
        prefix_transform = prefix_cls(opt)
        prefix_transform.warm_up()

        # 2. Length filter allowing at most 4 tokens per side.
        filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
        opt = Namespace(src_seq_length=4, tgt_seq_length=4)
        filter_transform = filter_cls(opt)

        # 3. Chain them: prefix first, then the length filter.
        transform_pipe = TransformPipe.build_from([prefix_transform, filter_transform])
        ex = {
            "src": ["Hello", ",", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }

        # The prefix adds a token to each 4-token side, so the filter rejects
        # the example and the pipe returns None.
        ex_after = transform_pipe.apply(copy.deepcopy(ex), corpus_name="trainset")
        self.assertIsNone(ex_after)

        # Statistics were recorded while applying the pipe ...
        self.assertGreater(len(transform_pipe.statistics.observables), 0)
        msg = transform_pipe.statistics.report()
        self.assertIsNotNone(msg)
        # ... and report() leaves no observables behind.
        self.assertEqual(len(transform_pipe.statistics.observables), 0)
|
|
|
|
|
class TestMiscTransform(unittest.TestCase):
    """Tests for the ``prefix`` and ``filtertoolong`` transforms."""

    def test_prefix(self):
        """Prefixes are prepended per corpus; missing/unknown corpus raises."""
        prefix_cls = get_transforms_cls(["prefix"])["prefix"]
        corpora = yaml.safe_load(
            """
        trainset:
            path_src: data/src-train.txt
            path_tgt: data/tgt-train.txt
            transforms: [prefix]
            weight: 1
            src_prefix: "⦅_pf_src⦆"
            tgt_prefix: "⦅_pf_tgt⦆"
        """
        )
        opt = Namespace(data=corpora, seed=-1)
        prefix_transform = prefix_cls(opt)
        prefix_transform.warm_up()
        self.assertIn("trainset", prefix_transform.prefix_dict)

        ex_in = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # Each failure mode gets its own context manager: inside a single
        # assertRaises block, the second call would never be executed.
        with self.assertRaises(ValueError):
            # no corpus_name supplied
            prefix_transform.apply(ex_in)
        with self.assertRaises(ValueError):
            # corpus that has no prefix configured
            prefix_transform.apply(ex_in, corpus_name="validset")
        ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
        self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
        self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")

    def test_filter_too_long(self):
        """Examples within the limits pass through; longer ones are dropped."""
        filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
        opt = Namespace(src_seq_length=100, tgt_seq_length=100)
        filter_transform = filter_cls(opt)

        ex_in = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # Within limits: the very same example object is returned.
        ex_out = filter_transform.apply(ex_in)
        self.assertIs(ex_out, ex_in)
        # The 4-token target now exceeds the limit, so the example is dropped.
        filter_transform.tgt_seq_length = 2
        ex_out = filter_transform.apply(ex_in)
        self.assertIsNone(ex_out)
|
|
|
|
|
class TestSubwordTransform(unittest.TestCase):
    """Tests for the subword tokenization transforms.

    Covers ``bpe``, ``sentencepiece`` and ``onmt_tokenize`` (backed by either a
    BPE or a SentencePiece model), using the sample models under ``data/``.
    """

    @classmethod
    def setUpClass(cls):
        # Options shared by every test in this class. Tests that need other
        # values shallow-copy this dict before overriding keys, so the shared
        # dict itself is never modified.
        cls.base_opts = {
            "seed": 3431,
            "share_vocab": False,
            "src_subword_model": "data/sample.bpe",
            "tgt_subword_model": "data/sample.bpe",
            "src_subword_nbest": 1,
            "tgt_subword_nbest": 1,
            "src_subword_alpha": 0.0,
            "tgt_subword_alpha": 0.0,
            "src_subword_vocab": "",
            "tgt_subword_vocab": "",
            "src_vocab_threshold": 0,
            "tgt_vocab_threshold": 0,
        }

    def test_bpe(self):
        """BPE encoding, BPE-dropout, and the observed caching effect."""
        bpe_cls = get_transforms_cls(["bpe"])["bpe"]
        opt = Namespace(**self.base_opts)
        bpe_cls._validate_options(opt)
        bpe_transform = bpe_cls(opt)
        bpe_transform.warm_up()
        ex = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # apply() rewrites ``ex`` in place; its return value is not used.
        bpe_transform.apply(ex, is_train=True)
        ex_gold = {
            "src": ["H@@", "ell@@", "o", "world", "."],
            "tgt": ["B@@", "on@@", "j@@", "our", "le", "mon@@", "de", "."],
        }
        self.assertEqual(ex, ex_gold)

        # Force BPE-dropout on the source side: with rate 1.0 every word is
        # split down to single characters during training (see gold_dropout).
        bpe_transform.dropout["src"] = 1.0
        tokens = ["Another", "world", "."]
        gold_bpe = ["A@@", "no@@", "ther", "world", "."]
        gold_dropout = [
            "A@@",
            "n@@",
            "o@@",
            "t@@",
            "h@@",
            "e@@",
            "r",
            "w@@",
            "o@@",
            "r@@",
            "l@@",
            "d",
            ".",
        ]
        # 1. Plain BPE encoding for non-training input (no dropout).
        after_bpe = bpe_transform._tokenize(tokens, is_train=False)
        self.assertEqual(after_bpe, gold_bpe)
        # 2. BPE-dropout is applied to training input.
        after_bpe = bpe_transform._tokenize(tokens, is_train=True)
        self.assertEqual(after_bpe, gold_dropout)
        # 3. NOTE: a later non-training call still returns the dropout
        # segmentation, i.e. disabling dropout has no effect for words already
        # seen. This pins current behaviour, presumably caused by the BPE
        # backend caching segmentations per word -- verify against the bpe
        # implementation if this expectation ever changes.
        # Keep the call order of steps 1-3: the result depends on it.
        after_bpe2 = bpe_transform._tokenize(tokens, is_train=False)
        self.assertEqual(after_bpe2, gold_dropout)

    def test_sentencepiece(self):
        """SentencePiece encoding, subword sampling and mask placement."""
        sp_cls = get_transforms_cls(["sentencepiece"])["sentencepiece"]
        base_opt = copy.copy(self.base_opts)
        base_opt["src_subword_model"] = "data/sample.sp.model"
        base_opt["tgt_subword_model"] = "data/sample.sp.model"
        opt = Namespace(**base_opt)
        sp_cls._validate_options(opt)
        sp_transform = sp_cls(opt)
        sp_transform.warm_up()

        ex = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # apply() rewrites ``ex`` in place.
        sp_transform.apply(ex, is_train=True)
        ex_gold = {
            "src": ["▁H", "el", "lo", "▁world", "▁."],
            "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
        }
        self.assertEqual(ex, ex_gold)

        # Subword regularization: raise the source n-best size to 4; the
        # assertions below show that only training input is affected.
        sp_transform.src_subword_nbest = 4
        tokens = ["Another", "world", "."]
        gold_sp = ["▁An", "other", "▁world", "▁."]

        # 1. Training input: a sampled segmentation (differs from gold_sp).
        after_sp = sp_transform._tokenize(tokens, is_train=True)
        self.assertEqual(after_sp, ["▁An", "o", "ther", "▁world", "▁."])

        # 2. Non-training input: the deterministic segmentation.
        after_sp = sp_transform._tokenize(tokens, is_train=False)
        self.assertEqual(after_sp, gold_sp)

        # Mask placement: the ⦅_mask_before_⦆ placeholder ends up as a single
        # "<blank>" token in the tokenized output.
        ex = {
            "src": "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆"
            "### Response : ⦅newline⦆⦅_mask_before_⦆response",
            "tgt": "",
        }
        ex["src"] = ex["src"].split(" ")
        ex_gold = {
            "src": [
                "▁",
                "#",
                "#",
                "#",
                "▁In",
                "struct",
                "ion",
                ":",
                "▁in",
                "struct",
                "ion",
                "▁",
                "#",
                "#",
                "#",
                "▁Re",
                "s",
                "p",
                "on",
                "s",
                "e",
                "▁",
                ":",
                "<blank>",
                "▁re",
                "s",
                "p",
                "on",
                "s",
                "e",
            ],
            "tgt": [],
        }
        sp_transform.apply(ex, is_train=True)
        self.assertEqual(ex, ex_gold)

    def test_pyonmttok_bpe(self):
        """onmt_tokenize with a BPE model and joiner annotation."""
        onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
        base_opt = copy.copy(self.base_opts)
        base_opt["src_subword_type"] = "bpe"
        base_opt["tgt_subword_type"] = "bpe"
        # Tokenizer keyword arguments, passed as a string.
        onmt_args = "{'mode': 'space', 'joiner_annotate': True}"
        base_opt["src_onmttok_kwargs"] = onmt_args
        base_opt["tgt_onmttok_kwargs"] = onmt_args
        base_opt["gpt2_pretok"] = False
        opt = Namespace(**base_opt)
        onmttok_cls._validate_options(opt)
        onmttok_transform = onmttok_cls(opt)
        onmttok_transform.warm_up()
        ex = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # apply() rewrites ``ex`` in place; "■" marks joined subwords.
        onmttok_transform.apply(ex, is_train=True)
        ex_gold = {
            "src": ["H■", "ell■", "o", "world", "."],
            "tgt": ["B■", "on■", "j■", "our", "le", "mon■", "de", "."],
        }
        self.assertEqual(ex, ex_gold)

        # Mask placement: ⦅newline⦆ becomes "\n" and ⦅_mask_before_⦆ becomes
        # a standalone "<blank>" token in the expected output.
        ex = {
            "src": (
                "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆"
                "### Response : ⦅newline⦆⦅_mask_before_⦆response"
            ),
            "tgt": "",
        }
        ex["src"] = ex["src"].split(" ")
        ex_gold = {
            "src": [
                "#■",
                "#■",
                "#",
                "In■",
                "struc■",
                "tion■",
                ":",
                "\n■",
                "in■",
                "struc■",
                "tion■",
                "\n■",
                "\n■",
                "#■",
                "#■",
                "#",
                "R■",
                "es■",
                "p■",
                "on■",
                "se",
                ":",
                "\n",
                "<blank>",
                "respon■",
                "se",
            ],
            "tgt": [],
        }
        onmttok_transform.apply(ex, is_train=True)
        self.assertEqual(ex, ex_gold)

    def test_pyonmttok_sp(self):
        """onmt_tokenize with a SentencePiece model and spacer annotation."""
        onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
        base_opt = copy.copy(self.base_opts)
        base_opt["src_subword_type"] = "sentencepiece"
        base_opt["tgt_subword_type"] = "sentencepiece"
        base_opt["src_subword_model"] = "data/sample.sp.model"
        base_opt["tgt_subword_model"] = "data/sample.sp.model"
        # Tokenizer keyword arguments, passed as a string.
        onmt_args = "{'mode': 'none', 'spacer_annotate': True}"
        base_opt["src_onmttok_kwargs"] = onmt_args
        base_opt["tgt_onmttok_kwargs"] = onmt_args
        base_opt["gpt2_pretok"] = False
        opt = Namespace(**base_opt)
        onmttok_cls._validate_options(opt)
        onmttok_transform = onmttok_cls(opt)
        onmttok_transform.warm_up()
        ex = {
            "src": ["Hello", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        # apply() rewrites ``ex`` in place; expected output matches the
        # plain sentencepiece transform.
        onmttok_transform.apply(ex, is_train=True)
        ex_gold = {
            "src": ["▁H", "el", "lo", "▁world", "▁."],
            "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
        }
        self.assertEqual(ex, ex_gold)

        # Mask placement: ⦅_mask_before_⦆ ends up as a "<blank>" token.
        ex = {
            "src": (
                "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆"
                "### Response : ⦅newline⦆⦅_mask_before_⦆response"
            ),
            "tgt": "",
        }
        ex["src"] = ex["src"].split(" ")
        onmttok_transform.apply(ex, is_train=True)
        ex_gold = {
            "src": [
                "▁",
                "#",
                "#",
                "#",
                "▁In",
                "struct",
                "ion",
                ":",
                "▁in",
                "struct",
                "ion",
                "▁",
                "#",
                "#",
                "#",
                "▁Re",
                "s",
                "p",
                "on",
                "se",
                "▁",
                ":",
                "<blank>",
                "▁re",
                "s",
                "p",
                "on",
                "se",
            ],
            "tgt": [],
        }
        self.assertEqual(ex, ex_gold)
|
|
|
|
|
class TestSamplingTransform(unittest.TestCase):
    """Random token-level noise transforms only alter training examples."""

    def _check_noise_only_when_training(self, transform):
        """Run ``transform`` on a fixed example, first in eval then train mode.

        The eval-mode result must equal the reference example; the train-mode
        result must differ from it. Each run receives a deep copy, so the
        reference itself is never touched by ``apply``.
        """
        reference = {
            "src": ["Hello", ",", "world", "."],
            "tgt": ["Bonjour", "le", "monde", "."],
        }
        expectations = (
            (False, self.assertEqual),
            (True, self.assertNotEqual),
        )
        for training, check in expectations:
            result = transform.apply(copy.deepcopy(reference), is_train=training)
            check(result, reference)

    def test_tokendrop(self):
        tokendrop_cls = get_transforms_cls(["tokendrop"])["tokendrop"]
        drop = tokendrop_cls(Namespace(seed=3434, tokendrop_temperature=0.1))
        drop.warm_up()
        self._check_noise_only_when_training(drop)

    def test_tokenmask(self):
        tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"]
        mask = tokenmask_cls(Namespace(seed=3434, tokenmask_temperature=0.1))
        mask.warm_up()
        self._check_noise_only_when_training(mask)

    def test_switchout(self):
        switchout_cls = get_transforms_cls(["switchout"])["switchout"]
        switch = switchout_cls(Namespace(seed=3434, switchout_temperature=0.1))
        # switchout cannot be warmed up without vocabularies.
        with self.assertRaises(ValueError):
            switch.warm_up(vocabs=None)
        fake_vocab = ["A", "Fake", "vocab"]
        switch.warm_up(
            vocabs={
                "src": Namespace(ids_to_tokens=fake_vocab),
                "tgt": Namespace(ids_to_tokens=fake_vocab),
            }
        )
        self._check_noise_only_when_training(switch)
|
|
|
|
|
class TestBARTNoising(unittest.TestCase):
    """Tests for the individual noise functions of ``BARTNoising``.

    setUp seeds the random generator before every test. The expectations
    below depend on that seed and on the order of the ``apply`` calls, so
    keep the calls in their current order.
    """

    def setUp(self):
        BARTNoising.set_random_seed(1234)
        self.MASK_TOK = "[MASK]"
        self.FAKE_VOCAB = "[TESTING]"

    def test_sentence_permute(self):
        """Sentences are detected by their end punctuation and permuted."""
        sent1 = ["Hello", "world", "."]
        sent2 = ["Sentence", "1", "!"]
        sent3 = ["Sentence", "2", "!"]
        sent4 = ["Sentence", "3", "!"]

        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            permute_sent_ratio=0.5,
            replace_length=0,
        )
        tokens = sent1 + sent2 + sent3 + sent4
        # Borders are the exclusive end index of each sentence.
        ends = bart_noise._get_sentence_borders(tokens).tolist()
        self.assertEqual(ends, [3, 6, 9, 12])
        tokens_perm = bart_noise.apply(tokens)
        # With this seed the first two sentences swap places.
        expected_tokens = sent2 + sent1 + sent3 + sent4
        self.assertEqual(expected_tokens, tokens_perm)

    def test_rotate(self):
        """rotate_ratio=1.0 always rotates; p=0.0 never does."""
        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            rotate_ratio=1.0,
            replace_length=0,
        )
        tokens = ["This", "looks", "really", "good", "!"]
        rotated = bart_noise.apply(tokens)
        self.assertNotEqual(tokens, rotated)
        not_rotate = bart_noise.rolling_noise(tokens, p=0.0)
        self.assertEqual(tokens, not_rotate)

    def test_token_insert(self):
        """Inserted tokens are a mix of random vocab tokens and MASK tokens."""
        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            mask_tok=self.MASK_TOK,
            insert_ratio=0.5,
            random_ratio=0.3,
            replace_length=0,
        )
        tokens = ["This", "looks", "really", "good", "!"]
        inserted = bart_noise.apply(tokens)
        n_insert = math.ceil(len(tokens) * bart_noise.insert_ratio)
        inserted_len = n_insert + len(tokens)
        self.assertEqual(len(inserted), inserted_len)

        # random_ratio of the inserted tokens come from the (fake) vocab ...
        n_random = math.ceil(n_insert * bart_noise.random_ratio)
        self.assertEqual(
            sum(1 if tok == self.FAKE_VOCAB else 0 for tok in inserted),
            n_random,
        )
        # ... and the remaining inserted tokens are MASK tokens.
        self.assertEqual(
            sum(1 if tok == self.MASK_TOK else 0 for tok in inserted),
            n_insert - n_random,
        )

    def test_token_mask(self):
        """Mask will be done on token level.

        Condition:
            * `mask_length` == subword;
            * or not specify subword marker (joiner/spacer) by `is_joiner`.
        """
        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            mask_tok=self.MASK_TOK,
            mask_ratio=0.5,
            mask_length="subword",
            replace_length=0,
        )
        tokens = ["H■", "ell■", "o", "world", "."]
        # At subword level every token counts as a word start.
        self.assertTrue(all(bart_noise._is_word_start(tokens)))
        n_tokens = len(tokens)
        n_masked = math.ceil(n_tokens * bart_noise.mask_ratio)

        # 1. replace_length=0: selected tokens are dropped.
        masked = bart_noise.apply(tokens)
        self.assertEqual(len(masked), n_tokens - n_masked)

        # 2. replace_length=1: selected tokens are replaced by MASK.
        bart_noise.replace_length = 1
        masked = bart_noise.apply(tokens)
        self.assertEqual(len(masked), n_tokens)
        self.assertEqual(
            sum([1 if tok == self.MASK_TOK else 0 for tok in masked]), n_masked
        )

    def test_whole_word_mask(self):
        """Mask will be done on whole word that may span multiple tokens.

        Condition:
            * `mask_length` == word;
            * specify subword marker in order to find word boundary.
        """
        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            mask_tok=self.MASK_TOK,
            mask_ratio=0.5,
            mask_length="word",
            is_joiner=True,
            replace_length=0,
        )
        tokens = ["H■", "ell■", "o", "wor■", "ld", "."]
        # Word starts are found from the joiner marker "■".
        token_starts = [True, False, False, True, False, True]
        self.assertEqual(bart_noise._is_word_start(tokens), token_starts)
        n_words = sum(token_starts)
        n_masked = math.ceil(n_words * bart_noise.mask_ratio)

        # 1. replace_length=0: selected words are dropped, removing at least
        # one token per selected word (more for multi-token words).
        masked = bart_noise.apply(copy.copy(tokens))
        self.assertLessEqual(len(masked), len(tokens) - n_masked)

        # 2. replace_length=1: each selected word becomes a single MASK, so
        # len(masked) depends on how many tokens the chosen words span.
        bart_noise.replace_length = 1
        masked = bart_noise.apply(copy.copy(tokens))
        self.assertEqual(
            sum(1 if tok == self.MASK_TOK else 0 for tok in masked), n_masked
        )

        # 3. replace_length=-1: every token of a selected word becomes MASK.
        bart_noise.replace_length = -1
        masked = bart_noise.apply(copy.copy(tokens))
        # The length is unchanged ...
        self.assertEqual(len(masked), len(tokens))
        # ... and multi-token words yield more MASK tokens than masked words.
        self.assertGreater(
            sum(1 if tok == self.MASK_TOK else 0 for tok in masked), n_masked
        )

    def test_span_infilling(self):
        """Smoke test: Poisson span infilling runs on a joiner-marked input."""
        bart_noise = BARTNoising(
            vocab=[self.FAKE_VOCAB],
            mask_tok=self.MASK_TOK,
            mask_ratio=0.5,
            mask_length="span-poisson",
            poisson_lambda=3.0,
            is_joiner=True,
            replace_length=1,
        )
        self.assertIsNotNone(bart_noise.mask_span_distribution)
        tokens = ["H■", "ell■", "o", "world", ".", "An■", "other", "!"]
        # Word starts are found from the joiner marker "■".
        token_starts = [True, False, False, True, True, True, False, True]
        self.assertEqual(bart_noise._is_word_start(tokens), token_starts)
        # Only checks that span infilling completes without error.
        bart_noise.apply(copy.copy(tokens))
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFeaturesTransform(unittest.TestCase):
    """Tests for ``inferfeats``: word-level features mapped onto subwords."""

    def test_inferfeats(self):
        """Features of ``src_original`` words are propagated to ``src`` tokens.

        The cases below reuse and mutate ``ex_in``, so their order matters.
        """
        inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"]
        opt = Namespace(reversible_tokenization="joiner")
        inferfeats_transform = inferfeats_cls(opt)

        # Case 1: no "src_feats" key -- the example is returned unchanged
        # (the very same object).
        ex_in = {
            "src": [
                "however",
                "■,",
                "according",
                "to",
                "the",
                "logs",
                "■,",
                "she",
                "is",
                "hard",
                "■-■",
                "working",
                "■.",
            ],
            "src_original": [
                "however,",
                "according",
                "to",
                "the",
                "logs,",
                "she",
                "is",
                "hard-working.",
            ],
        }
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertIs(ex_out, ex_in)

        # Case 2: one feature per original word (8) expands to one feature
        # per joiner-tokenized token (13); split pieces repeat their word's
        # feature.
        ex_in["src_feats"] = [["1", "2", "3", "4", "5", "6", "7", "8"]]
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertEqual(
            ex_out["src_feats"][0],
            ["1", "1", "2", "3", "4", "5", "5", "6", "7", "8", "8", "8", "8"],
        )

        # Case 3: case-markup placeholders in "src" also receive a feature,
        # taken from the word they are attached to in the expected output.
        # "src_original" is reused from case 1.
        ex_in["src"] = [
            "⦅mrk_case_modifier_C⦆",
            "however",
            "■,",
            "according",
            "to",
            "the",
            "logs",
            "■,",
            "⦅mrk_begin_case_region_U⦆",
            "she",
            "is",
            "hard",
            "■-■",
            "working",
            "⦅mrk_end_case_region_U⦆",
            "■.",
        ]
        ex_in["src_feats"] = [["1", "2", "3", "4", "5", "6", "7", "8"]]
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertEqual(
            ex_out["src_feats"][0],
            [
                "1",
                "1",
                "1",
                "2",
                "3",
                "4",
                "5",
                "5",
                "6",
                "6",
                "7",
                "8",
                "8",
                "8",
                "8",
                "8",
            ],
        )

        # Case 4: "src_original" is already mostly tokenized; only
        # "hard-working" is further split, so only its feature ("10") is
        # repeated.
        ex_in = {
            "src": [
                "however",
                "■,",
                "according",
                "to",
                "the",
                "logs",
                "■,",
                "she",
                "is",
                "hard",
                "■-■",
                "working",
                "■.",
            ],
            "src_original": [
                "however",
                "■,",
                "according",
                "to",
                "the",
                "logs",
                "■,",
                "she",
                "is",
                "hard-working",
                "■.",
            ],
            "src_feats": [["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"]],
        }
        ex_out = inferfeats_transform.apply(ex_in)
        self.assertEqual(
            ex_out["src_feats"][0],
            ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "10", "10", "11"],
        )
|
|
|
|
|
class TestInsertMaskBeforePlaceholder(unittest.TestCase):
    """Tests for the ``insert_mask_before_placeholder`` transform."""

    @classmethod
    def setUpClass(cls):
        cls.base_opts = {
            "response_pattern": "Response : ⦅newline⦆",
        }

    def test_insert_mask_before_placeholder(self):
        """The mask placeholder is inserted after ``response_pattern``.

        The expected output shows ⦅_mask_before_⦆ placed right after the
        pattern's trailing ⦅newline⦆, on both the source and target sides.
        """
        insert_mask_before_placeholder_cls = get_transforms_cls(
            ["insert_mask_before_placeholder"]
        )["insert_mask_before_placeholder"]
        opt = Namespace(**self.base_opts)
        insert_mask_before_placeholder_transform = insert_mask_before_placeholder_cls(
            opt
        )
        src = (
            "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆"
            "### Response : ⦅newline⦆response"
        ).split(" ")
        # Give tgt its own list. Aliasing src would let an in-place edit of
        # one side leak into (or be applied twice to) the other, which this
        # test would not detect.
        ex_in = {"src": src, "tgt": list(src)}
        ex_out = insert_mask_before_placeholder_transform.apply(ex_in)
        expected = [
            "###",
            "Instruction:",
            "⦅newline⦆instruction⦅newline⦆⦅newline⦆###",
            "Response",
            ":",
            "⦅newline⦆⦅_mask_before_⦆response",
        ]
        ex_gold = {"src": expected, "tgt": list(expected)}
        self.assertEqual(ex_out, ex_gold)
|
|