File size: 3,334 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import unittest
import glob
import os
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.train_single import prepare_transforms_vocabs
from onmt.constants import CorpusName
# Path prefix for artifacts produced by these tests: used as the `save_data`
# override in some parameterized cases and as the glob prefix when removing
# generated `*.pt` files during cleanup.
SAVE_DATA_PREFIX = "data/test_data_prepare"
def get_default_opts():
    """Build the baseline options shared by the data-preparation tests.

    A parser is populated via ``dynamic_prepare_opts`` and fed a fixed
    argument list pointing at the sample config and vocabulary files.

    Returns:
        The parsed options object, after ``copy_attn`` has been set to
        ``False`` and it has been passed through
        ``ArgumentParser.validate_prepare_opts``.
    """
    arg_parser = ArgumentParser(description="data sample prepare")
    dynamic_prepare_opts(arg_parser)

    cli_args = [
        "-config",
        "data/data.yaml",
        "-src_vocab",
        "data/vocab-train.src",
        "-tgt_vocab",
        "data/vocab-train.tgt",
    ]

    # Only the first element of the parse result is used.
    parse_result = arg_parser.parse_known_args(cli_args)
    opt = parse_result[0]

    # Inject a dummy training option that may be needed when building fields.
    opt.copy_attn = False

    ArgumentParser.validate_prepare_opts(opt)
    return opt
# Parsed once at import time and shared by every TestData instance; generated
# tests that override settings deep-copy it first, the "standard" test uses it
# directly.
default_opts = get_default_opts()
class TestData(unittest.TestCase):
    """Test case that runs the data-preparation step under various options.

    The concrete ``test_*`` methods are attached to this class at import
    time; each of them calls :meth:`dataset_build` with either the shared
    default options or a modified deep copy of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Baseline options shared by all tests (module-level `default_opts`).
        self.opt = default_opts

    def dataset_build(self, opt):
        """Run ``prepare_transforms_vocabs`` with ``opt`` and clean up after it.

        Args:
            opt: options object for this run; ``skip_empty_level`` and
                ``save_data`` are read from it.

        Raises:
            IOError: re-raised when ``opt.skip_empty_level`` is not
                ``"error"``; otherwise the error is printed and swallowed.
        """
        try:
            prepare_transforms_vocabs(opt, {})
        except SystemExit as err:
            # An early exit is reported but does not fail the test.
            print(err)
        except IOError as err:
            if opt.skip_empty_level != "error":
                # Bare raise keeps the original traceback intact.
                raise
            print(f"Catched IOError: {err}")
        finally:
            # Remove the generated *pt files.
            for pt in glob.glob(SAVE_DATA_PREFIX + "*.pt"):
                os.remove(pt)
            # Clean up based on the options of *this* run, not the shared
            # defaults: a test may override `save_data`, and the samples to
            # remove live next to the path that was actually used.
            if opt.save_data:
                # Remove the generated data samples
                sample_path = os.path.join(
                    os.path.dirname(opt.save_data), CorpusName.SAMPLE
                )
                if os.path.exists(sample_path):
                    for f in glob.glob(sample_path + "/*"):
                        os.remove(f)
                    os.rmdir(sample_path)
def _add_test(param_setting, methodname):
    """Attach a generated test method to ``TestData``.

    The generated method calls ``getattr(self, methodname)`` with the test
    case's options. When ``param_setting`` is non-empty the options are a
    deep copy of ``self.opt`` with every ``(param, setting)`` pair applied;
    otherwise ``self.opt`` itself is passed.

    Args:
        param_setting: list of tuples of (param, setting)
        methodname: name of the method that gets called
    """
    # The test name embeds the whitespace-free repr of the overrides, or
    # "standard" when there are none.
    if param_setting:
        suffix = "_".join(str(param_setting).split())
    else:
        suffix = "standard"
    name = "test_" + methodname + "_" + suffix

    def test_method(self):
        opt = self.opt
        if param_setting:
            # Never mutate the shared defaults; apply overrides to a copy.
            opt = copy.deepcopy(opt)
            for param, setting in param_setting:
                setattr(opt, param, setting)
        runner = getattr(self, methodname)
        runner(opt)

    test_method.__name__ = name
    setattr(TestData, name, test_method)
# Option overrides for the generated `dataset_build` tests. Each entry is a
# list of (option, value) pairs; the empty list runs with the defaults.
# The entries' string form becomes part of the generated test name, so they
# must remain lists of tuples.
test_databuild = [
    [],
    # Vocabulary sizes
    [("src_vocab_size", 1), ("tgt_vocab_size", 1)],
    [("src_vocab_size", 10000), ("tgt_vocab_size", 10000)],
    # Source length / truncation
    [("src_seq_len", 1)],
    [("src_seq_len", 5000)],
    [("src_seq_length_trunc", 1)],
    [("src_seq_length_trunc", 5000)],
    # Target length / truncation
    [("tgt_seq_len", 1)],
    [("tgt_seq_len", 5000)],
    [("tgt_seq_length_trunc", 1)],
    [("tgt_seq_length_trunc", 5000)],
    # Flags
    [("copy_attn", True)],
    [("share_vocab", True)],
    # Sample dumping
    [("n_sample", 30), ("save_data", SAVE_DATA_PREFIX)],
    [
        ("n_sample", 30),
        ("save_data", SAVE_DATA_PREFIX),
        ("skip_empty_level", "error"),
    ],
]

# Register one test method on TestData per override set.
for p in test_databuild:
    _add_test(p, "dataset_build")
|