File size: 3,334 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
import unittest
import glob
import os

from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.train_single import prepare_transforms_vocabs
from onmt.constants import CorpusName


SAVE_DATA_PREFIX = "data/test_data_prepare"


def get_default_opts():
    """Build the baseline option namespace used by the tests in this module.

    A fixed argument list (config file plus source/target vocabulary paths)
    is parsed with the dynamic-prepare option set, one extra attribute is
    injected, and the result is validated before being returned.

    Returns:
        The parsed and validated option namespace.
    """
    arg_parser = ArgumentParser(description="data sample prepare")
    dynamic_prepare_opts(arg_parser)

    cli_args = [
        "-config",
        "data/data.yaml",
        "-src_vocab",
        "data/vocab-train.src",
        "-tgt_vocab",
        "data/vocab-train.tgt",
    ]

    # parse_known_args returns the namespace first; leftovers are ignored.
    parse_result = arg_parser.parse_known_args(cli_args)
    parsed = parse_result[0]

    # Inject a dummy training option that may be needed when building fields.
    parsed.copy_attn = False

    ArgumentParser.validate_prepare_opts(parsed)
    return parsed


# Parsed once at import time; every TestData instance stores a reference to
# this same object as ``self.opt``.
default_opts = get_default_opts()


class TestData(unittest.TestCase):
    """Test case that runs transform/vocab preparation under various options.

    The class body defines no ``test_*`` methods itself; they are attached
    to the class with ``setattr`` by ``_add_test`` further down the module.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shared baseline options (module-level object, not a copy).
        self.opt = default_opts

    def dataset_build(self, opt):
        """Run ``prepare_transforms_vocabs`` with ``opt`` and remove artifacts.

        Args:
            opt: option namespace to run with. Reads ``skip_empty_level``
                and ``save_data`` from it.

        Raises:
            IOError: re-raised when ``opt.skip_empty_level`` is not
                ``"error"``; when it is ``"error"`` the IOError is printed
                and swallowed. ``SystemExit`` is always printed and
                swallowed.
        """
        try:
            prepare_transforms_vocabs(opt, {})
        except SystemExit as err:
            print(err)
        except IOError as err:
            if opt.skip_empty_level != "error":
                # Bare raise keeps the original traceback intact.
                raise
            print(f"Catched IOError: {err}")
        finally:
            # Remove the generated *pt files.
            for pt in glob.glob(SAVE_DATA_PREFIX + "*.pt"):
                os.remove(pt)
            # Look at the options actually used for this run: tests override
            # ``save_data`` on a copy, so the shared defaults in ``self.opt``
            # do not tell us where the samples were written.
            if opt.save_data:
                # Remove the generated data samples
                sample_path = os.path.join(
                    os.path.dirname(opt.save_data), CorpusName.SAMPLE
                )
                if os.path.exists(sample_path):
                    for f in glob.glob(sample_path + "/*"):
                        os.remove(f)
                    os.rmdir(sample_path)


def _add_test(param_setting, methodname):
    """
    Register a generated test method on TestData.

    The generated method calls ``getattr(self, methodname)`` with the
    instance's options; when ``param_setting`` is non-empty the options are
    deep-copied first and each (param, setting) pair is applied to the copy.

    Args:
        param_setting: list of tuples of (param, setting)
        methodname: name of the method that gets called
    """
    # The test name embeds the textual form of the settings, with whitespace
    # replaced by underscores; an empty setting list yields "..._standard".
    if param_setting:
        suffix = "_".join(str(param_setting).split())
    else:
        suffix = "standard"
    name = "test_" + methodname + "_" + suffix

    def test_method(self):
        opt = self.opt
        if param_setting:
            # Work on a private copy so the shared options stay untouched.
            opt = copy.deepcopy(opt)
            for param, setting in param_setting:
                setattr(opt, param, setting)
        runner = getattr(self, methodname)
        runner(opt)

    test_method.__name__ = name
    setattr(TestData, name, test_method)


# Option overrides for the generated ``dataset_build`` tests. Each entry is a
# list of (option name, value) pairs that _add_test applies to a deep copy of
# the default options; the empty list runs with the defaults unchanged.
# NOTE: _add_test embeds ``str(entry)`` in the generated test name, so
# changing the literal form of an entry also changes the test's name.
test_databuild = [
    [],
    [("src_vocab_size", 1), ("tgt_vocab_size", 1)],
    [("src_vocab_size", 10000), ("tgt_vocab_size", 10000)],
    [("src_seq_len", 1)],
    [("src_seq_len", 5000)],
    [("src_seq_length_trunc", 1)],
    [("src_seq_length_trunc", 5000)],
    [("tgt_seq_len", 1)],
    [("tgt_seq_len", 5000)],
    [("tgt_seq_length_trunc", 1)],
    [("tgt_seq_length_trunc", 5000)],
    [("copy_attn", True)],
    [("share_vocab", True)],
    [("n_sample", 30), ("save_data", SAVE_DATA_PREFIX)],
    [("n_sample", 30), ("save_data", SAVE_DATA_PREFIX), ("skip_empty_level", "error")],
]

# Attach one test method to TestData per entry above.
for p in test_databuild:
    _add_test(p, "dataset_build")