Upload 313 files
Note: this view is limited to 50 files because the commit contains too many changes; the remaining files are only visible in the raw diff.
- e_smiles.py +0 -0
- infer.sh +10 -0
- inference.py +5 -0
- onmt/__init__.py +24 -0
- onmt/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/__pycache__/__init__.cpython-37.pyc +0 -0
- onmt/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/__pycache__/constants.cpython-311.pyc +0 -0
- onmt/__pycache__/constants.cpython-38.pyc +0 -0
- onmt/__pycache__/inference_engine.cpython-38.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-311.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-38.pyc +0 -0
- onmt/__pycache__/opts.cpython-311.pyc +0 -0
- onmt/__pycache__/opts.cpython-38.pyc +0 -0
- onmt/__pycache__/train_single.cpython-38.pyc +0 -0
- onmt/__pycache__/trainer.cpython-38.pyc +0 -0
- onmt/bin/__init__.py +0 -0
- onmt/bin/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/average_models.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/build_vocab.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/release_model.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/server.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/train.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-38.pyc +0 -0
- onmt/bin/average_models.py +60 -0
- onmt/bin/build_vocab.py +287 -0
- onmt/bin/release_model.py +39 -0
- onmt/bin/server.py +167 -0
- onmt/bin/train.py +71 -0
- onmt/bin/translate.py +60 -0
- onmt/constants.py +41 -0
- onmt/decoders/__init__.py +63 -0
- onmt/decoders/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-38.pyc +0 -0
- onmt/decoders/cnn_decoder.py +141 -0
- onmt/decoders/decoder.py +405 -0
- onmt/decoders/ensemble.py +150 -0
- onmt/decoders/transformer.py +835 -0
- onmt/encoders/__init__.py +67 -0
- onmt/encoders/__pycache__/__init__.cpython-311.pyc +0 -0
e_smiles.py
ADDED
The diff for this file is too large to render.
infer.sh
ADDED
@@ -0,0 +1,10 @@
python inference.py \
    --model trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt \
    --src ./tmp_data/src.txt \
    --output ./tmp_data/tgt.txt \
    --beam_size 10 \
    --n_best 10 \
    --batch_size 16384 \
    --batch_type tokens \
    --max_length 500 \
    --seed 0

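Since infer.sh requests 10 hypotheses per input (--n_best 10), the output file contains 10 consecutive lines per source molecule. Below is a minimal post-processing sketch, assuming the standard OpenNMT-py convention of writing n_best hypotheses per source line; file names follow infer.sh.

```
# Group the n_best predictions written by infer.sh back per source molecule.
# Assumes the standard OpenNMT-py layout: n_best consecutive lines per input.
N_BEST = 10

with open("./tmp_data/src.txt") as f_src, open("./tmp_data/tgt.txt") as f_tgt:
    sources = [line.strip() for line in f_src]
    predictions = [line.strip() for line in f_tgt]

grouped = [predictions[i * N_BEST:(i + 1) * N_BEST] for i in range(len(sources))]
for src, hyps in zip(sources, grouped):
    print(src, "->", hyps[0])  # top-1 hypothesis
```
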
inference.py
ADDED
@@ -0,0 +1,5 @@
#!/usr/bin/env python
from onmt.bin.translate import main

if __name__ == "__main__":
    main()

onmt/__init__.py
ADDED
@@ -0,0 +1,24 @@
""" Main entry point of the ONMT library """
import onmt.inputters
import onmt.encoders
import onmt.decoders
import onmt.models
import onmt.utils
import onmt.modules
import sys
import onmt.utils.optimizers

onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer
sys.modules["onmt.Optim"] = onmt.utils.optimizers

# For Flake
__all__ = [
    onmt.inputters,
    onmt.encoders,
    onmt.decoders,
    onmt.models,
    onmt.utils,
    onmt.modules,
]

__version__ = "3.4.1"

The following compiled cache files and the empty package initializer were also added (binary or empty, not rendered in the diff):

- onmt/__pycache__/__init__.cpython-311.pyc (892 Bytes)
- onmt/__pycache__/__init__.cpython-37.pyc (605 Bytes)
- onmt/__pycache__/__init__.cpython-38.pyc (603 Bytes)
- onmt/__pycache__/constants.cpython-311.pyc (2.06 kB)
- onmt/__pycache__/constants.cpython-38.pyc (1.61 kB)
- onmt/__pycache__/inference_engine.cpython-38.pyc (3.22 kB)
- onmt/__pycache__/model_builder.cpython-311.pyc (19.4 kB)
- onmt/__pycache__/model_builder.cpython-38.pyc (10.6 kB)
- onmt/__pycache__/opts.cpython-311.pyc (58 kB)
- onmt/__pycache__/opts.cpython-38.pyc (38.4 kB)
- onmt/__pycache__/train_single.cpython-38.pyc (6.41 kB)
- onmt/__pycache__/trainer.cpython-38.pyc (14.5 kB)
- onmt/bin/__init__.py (empty file)
- onmt/bin/__pycache__/__init__.cpython-311.pyc (171 Bytes)
- onmt/bin/__pycache__/__init__.cpython-38.pyc (145 Bytes)
- onmt/bin/__pycache__/average_models.cpython-38.pyc (1.48 kB)
- onmt/bin/__pycache__/build_vocab.cpython-38.pyc (8.77 kB)
- onmt/bin/__pycache__/release_model.cpython-38.pyc (1.17 kB)
- onmt/bin/__pycache__/server.cpython-38.pyc (5.08 kB)
- onmt/bin/__pycache__/train.cpython-38.pyc (1.84 kB)
- onmt/bin/__pycache__/translate.cpython-311.pyc (2.89 kB)
- onmt/bin/__pycache__/translate.cpython-38.pyc (1.77 kB)

onmt/bin/average_models.py
ADDED
@@ -0,0 +1,60 @@
#!/usr/bin/env python
import argparse
import torch


def average_models(model_files, fp32=False):
    vocab = None
    opt = None
    avg_model = None
    avg_generator = None

    for i, model_file in enumerate(model_files):
        m = torch.load(model_file, map_location="cpu")
        model_weights = m["model"]
        generator_weights = m["generator"]

        if fp32:
            for k, v in model_weights.items():
                model_weights[k] = v.float()
            for k, v in generator_weights.items():
                generator_weights[k] = v.float()

        if i == 0:
            vocab, opt = m["vocab"], m["opt"]
            avg_model = model_weights
            avg_generator = generator_weights
        else:
            for k, v in avg_model.items():
                avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)

            for k, v in avg_generator.items():
                avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)

    final = {
        "vocab": vocab,
        "opt": opt,
        "optim": None,
        "generator": avg_generator,
        "model": avg_model,
    }
    return final


def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-models", "-m", nargs="+", required=True, help="List of models"
    )
    parser.add_argument("-output", "-o", required=True, help="Output file")
    parser.add_argument(
        "-fp32", "-f", action="store_true", help="Cast params to float32"
    )
    opt = parser.parse_args()

    final = average_models(opt.models, opt.fp32)
    torch.save(final, opt.output)


if __name__ == "__main__":
    main()

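average_models can also be called directly from Python; a minimal sketch, with placeholder checkpoint paths:

```
import torch
from onmt.bin.average_models import average_models

# Hypothetical checkpoint paths; average the weights and save a new checkpoint.
checkpoints = ["model_step_10000.pt", "model_step_20000.pt"]
averaged = average_models(checkpoints, fp32=True)
torch.save(averaged, "model_avg.pt")
```
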
onmt/bin/build_vocab.py
ADDED
@@ -0,0 +1,287 @@
#!/usr/bin/env python
"""Get vocabulary countings from transformed corpora samples."""
import os
import copy
import multiprocessing as mp
import pyonmttok
from functools import partial
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import set_random_seed, check_path
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter


MAXBUCKETSIZE = 256000


def write_files_from_queues(sample_path, queues):
    """
    Standalone process that reads data from
    queues in order and writes to sample files.
    """
    os.makedirs(sample_path, exist_ok=True)
    for c_name in queues.keys():
        dest_base = os.path.join(sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
        with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
            dest_base + ".tgt", "w", encoding="utf-8"
        ) as f_tgt:
            while True:
                _next = False
                for q in queues[c_name]:
                    item = q.get()
                    if item == "blank":
                        continue
                    if item == "break":
                        _next = True
                        break
                    _, src_line, tgt_line = item
                    f_src.write(src_line + "\n")
                    f_tgt.write(tgt_line + "\n")
                if _next:
                    break


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            maybe_example = process(CorpusTask.TRAIN, [item])
            if maybe_example is not None:
                maybe_example = maybe_example[0]
            else:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (
                maybe_example["src"]["src"],
                maybe_example["tgt"]["tgt"],
            )
            sub_counter_src.update(src_line.split(" "))
            sub_counter_tgt.update(tgt_line.split(" "))

            if "feats" in maybe_example["src"]:
                src_feats_lines = maybe_example["src"]["feats"]
                for k in range(opts.n_src_feats):
                    sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
            else:
                src_feats_lines = []

            if opts.dump_samples:
                src_pretty_line = append_features_to_text(src_line, src_feats_lines)
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_pretty_line, tgt_line)
                )
            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
    """Add the queues as attribute of the pooled function."""
    build_sub_vocab.queues = queues


def build_vocab(opts, transforms, n_sample=3):
    """Build vocabulary from data."""

    if n_sample == -1:
        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
    elif n_sample > 0:
        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
    else:
        raise ValueError(f"n_sample should be > 0 or == -1, got {n_sample}.")

    if opts.dump_samples:
        logger.info(
            "The samples on which the vocab is built will be "
            "dumped to disk. It may slow down the process."
        )
    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]

    queues = {
        c_name: [
            mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
        ]
        for c_name in corpora.keys()
    }
    sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
    if opts.dump_samples:
        write_process = mp.Process(
            target=write_files_from_queues, args=(sample_path, queues), daemon=True
        )
        write_process.start()
    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
        func = partial(
            build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
        )
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
            func, range(0, opts.num_threads)
        ):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            for i in range(opts.n_src_feats):
                counter_src_feats[i].update(sub_counter_src_feats[i])
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats


def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
    def _mp_ingest(data):
        func = partial(process, CorpusName.TRAIN)
        chunk = len(data) // opts.num_threads
        with mp.Pool(opts.num_threads) as pool:
            buckets = pool.map(
                func,
                [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
            )
        for bucket in buckets:
            for ex in bucket:
                if ex is not None:
                    src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
                    learner.ingest(src_line)
                    learner.ingest(tgt_line)

    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
    to_ingest = []
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            if n_sample >= 0 and i >= n_sample:
                break
            if len(to_ingest) >= MAXBUCKETSIZE:
                _mp_ingest(to_ingest)
                to_ingest = []
            to_ingest.append(item)
    _mp_ingest(to_ingest)


def make_learner(tokenization_type, symbols):
    if tokenization_type == "bpe":
        # BPE training
        learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
    elif tokenization_type == "sentencepiece":
        # SentencePiece training
        learner = pyonmttok.SentencePieceLearner(
            vocab_size=symbols, character_coverage=0.98
        )
    return learner


def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as follows and can be passed as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert (
        opts.n_sample == -1 or opts.n_sample > 1
    ), f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)

    if opts.learn_subwords:
        logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
        learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
        if opts.src_subword_model is not None:
            tok_path = opts.src_subword_model
        else:
            data_dir = os.path.split(opts.save_data)[0]
            if not os.path.exists(data_dir):
                os.makedirs(data_dir)
            tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
        save_opts = copy.deepcopy(opts)
        opts.src_subword_type = "none"
        opts.tgt_subword_type = "none"
        opts.src_onmttok_kwargs["joiner_annotate"] = False
        opts.tgt_onmttok_kwargs["joiner_annotate"] = False
        transforms = make_transforms(opts, transforms_cls, None)
        ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
        logger.info(f"Learning {tok_path} model, patience")
        learner.learn(tok_path)
        opts = save_opts

    transforms = make_transforms(opts, transforms_cls, None)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample
    )

    logger.info(f"Counters src: {len(src_counter)}")
    logger.info(f"Counters tgt: {len(tgt_counter)}")
    for i, feat_counter in enumerate(src_feats_counter):
        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for i, c in enumerate(src_feats_counter):
        save_counter(c, f"{opts.src_vocab}_feat{i}")


def _get_parser():
    parser = ArgumentParser(description="build_vocab.py")
    dynamic_prepare_opts(parser, build_vocab_only=True)
    return parser


def main():
    parser = _get_parser()
    opts, unknown = parser.parse_known_args()
    build_vocab_main(opts)


if __name__ == "__main__":
    main()

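The vocab files written by save_counter are plain text, one `<token>\t<count>` pair per line, as described in the build_vocab_main docstring. A minimal sketch for loading such a file back into Python (the path vocab.src is a hypothetical output name):

```
from collections import OrderedDict

def load_vocab_counts(path):
    # Parse the "<token>\t<count>" format written by save_counter above.
    counts = OrderedDict()
    with open(path, encoding="utf8") as fin:
        for line in fin:
            tok, count = line.rstrip("\n").split("\t")
            counts[tok] = int(count)
    return counts

vocab = load_vocab_counts("vocab.src")  # hypothetical output of build_vocab.py
```
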
onmt/bin/release_model.py
ADDED
@@ -0,0 +1,39 @@
#!/usr/bin/env python
import argparse
import torch


def main():
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference"
    )
    parser.add_argument("--model", "-m", help="The model path", required=True)
    parser.add_argument("--output", "-o", help="The output path", required=True)
    parser.add_argument(
        "--format",
        choices=["pytorch", "ctranslate2"],
        default="pytorch",
        help="The format of the released model",
    )
    parser.add_argument(
        "--quantization",
        "-q",
        choices=["int8", "int16", "float16", "int8_float16"],
        default=None,
        help="Quantization type for CT2 model.",
    )
    opt = parser.parse_args()

    model = torch.load(opt.model, map_location=torch.device("cpu"))
    if opt.format == "pytorch":
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        import ctranslate2

        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        converter.convert(opt.output, force=True, quantization=opt.quantization)


if __name__ == "__main__":
    main()

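The same conversion can be done programmatically, mirroring the ctranslate2 branch above; a sketch with placeholder paths:

```
import ctranslate2

# Convert an OpenNMT-py checkpoint to a CTranslate2 model directory,
# as release_model.py does for --format ctranslate2 (paths are placeholders).
converter = ctranslate2.converters.OpenNMTPyConverter("model_release.pt")
converter.convert("ct2_model", force=True, quantization="int8")
```
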
onmt/bin/server.py
ADDED
@@ -0,0 +1,167 @@
#!/usr/bin/env python
import configargparse

from flask import Flask, jsonify, request
from waitress import serve
from onmt.translate import TranslationServer, ServerModelError
import logging
from logging.handlers import RotatingFileHandler

STATUS_OK = "ok"
STATUS_ERROR = "error"


def start(config_file, url_root="./translator", host="0.0.0.0", port=5000, debug=False):
    def prefix_route(route_function, prefix="", mask="{0}{1}"):
        def newroute(route, *args, **kwargs):
            return route_function(mask.format(prefix, route), *args, **kwargs)

        return newroute

    if debug:
        logger = logging.getLogger("main")
        log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
        file_handler = RotatingFileHandler(
            "debug_requests.log", maxBytes=1000000, backupCount=10
        )
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    app = Flask(__name__)
    app.route = prefix_route(app.route, url_root)
    translation_server = TranslationServer()
    translation_server.start(config_file)

    @app.route("/models", methods=["GET"])
    def get_models():
        out = translation_server.list_models()
        return jsonify(out)

    @app.route("/health", methods=["GET"])
    def health():
        out = {}
        out["status"] = STATUS_OK
        return jsonify(out)

    @app.route("/clone_model/<int:model_id>", methods=["POST"])
    def clone_model(model_id):
        out = {}
        data = request.get_json(force=True)
        timeout = -1
        if "timeout" in data:
            timeout = data["timeout"]
            del data["timeout"]

        opt = data.get("opt", None)
        try:
            model_id, load_time = translation_server.clone_model(model_id, opt, timeout)
        except ServerModelError as e:
            out["status"] = STATUS_ERROR
            out["error"] = str(e)
        else:
            out["status"] = STATUS_OK
            out["model_id"] = model_id
            out["load_time"] = load_time

        return jsonify(out)

    @app.route("/unload_model/<int:model_id>", methods=["GET"])
    def unload_model(model_id):
        out = {"model_id": model_id}

        try:
            translation_server.unload_model(model_id)
            out["status"] = STATUS_OK
        except Exception as e:
            out["status"] = STATUS_ERROR
            out["error"] = str(e)

        return jsonify(out)

    @app.route("/translate", methods=["POST"])
    def translate():
        inputs = request.get_json(force=True)
        if debug:
            logger.info(inputs)
        out = {}
        try:
            trans, scores, n_best, _, aligns, align_scores = translation_server.run(
                inputs
            )
            assert len(trans) == len(inputs) * n_best
            assert len(scores) == len(inputs) * n_best
            assert len(aligns) == len(inputs) * n_best

            out = [[] for _ in range(n_best)]
            for i in range(len(trans)):
                response = {
                    "src": inputs[i // n_best]["src"],
                    "tgt": trans[i],
                    "n_best": n_best,
                    "pred_score": scores[i],
                }
                if len(aligns[i]) > 0 and aligns[i][0] is not None:
                    response["align"] = aligns[i]
                    response["align_score"] = align_scores[i]
                out[i % n_best].append(response)
        except ServerModelError as e:
            model_id = inputs[0].get("id")
            if debug:
                logger.warning(
                    "Unload model #{} " "because of an error".format(model_id)
                )
            translation_server.models[model_id].unload()
            out["error"] = str(e)
            out["status"] = STATUS_ERROR
        if debug:
            logger.info(out)
        return jsonify(out)

    @app.route("/to_cpu/<int:model_id>", methods=["GET"])
    def to_cpu(model_id):
        out = {"model_id": model_id}
        translation_server.models[model_id].to_cpu()

        out["status"] = STATUS_OK
        return jsonify(out)

    @app.route("/to_gpu/<int:model_id>", methods=["GET"])
    def to_gpu(model_id):
        out = {"model_id": model_id}
        translation_server.models[model_id].to_gpu()

        out["status"] = STATUS_OK
        return jsonify(out)

    serve(app, host=host, port=port)


def _get_parser():
    parser = configargparse.ArgumentParser(
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        description="OpenNMT-py REST Server",
    )
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default="5000")
    parser.add_argument("--url_root", type=str, default="/translator")
    parser.add_argument("--debug", "-d", action="store_true")
    parser.add_argument(
        "--config", "-c", type=str, default="./available_models/conf.json"
    )
    return parser


def main():
    parser = _get_parser()
    args = parser.parse_args()
    start(
        args.config,
        url_root=args.url_root,
        host=args.ip,
        port=args.port,
        debug=args.debug,
    )


if __name__ == "__main__":
    main()

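A minimal client sketch for the /translate route, assuming the server is started with the default port and url_root and that each input carries the `src` and `id` fields read by the handler above (model id 100 is a placeholder):

```
import requests

# Query the REST server started by onmt/bin/server.py
# (defaults assumed: port 5000, url_root /translator).
payload = [{"src": "C C O", "id": 100}]  # "id" selects the model; "src" is the input
resp = requests.post("http://127.0.0.1:5000/translator/translate", json=payload)
print(resp.json())
```
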
onmt/bin/train.py
ADDED
@@ -0,0 +1,71 @@
#!/usr/bin/env python
"""Train models with dynamic data."""
import torch
from functools import partial
from onmt.utils.distributed import ErrorHandler, spawned_train
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.opts import train_opts
from onmt.train_single import main as single_main


# Set sharing strategy manually instead of default based on the OS.
# torch.multiprocessing.set_sharing_strategy('file_system')


def train(opt):
    init_logger(opt.log_file)

    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    train_process = partial(single_main)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        mp = torch.multiprocessing.get_context("spawn")
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            procs.append(
                mp.Process(
                    target=spawned_train,
                    args=(train_process, opt, device_id, error_queue),
                    daemon=False,
                )
            )
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        for p in procs:
            p.join()

    elif nb_gpu == 1:  # case 1 GPU only
        train_process(opt, device_id=0)
    else:  # case only CPU
        train_process(opt, device_id=-1)


def _get_parser():
    parser = ArgumentParser(description="train.py")
    train_opts(parser)
    return parser


def main():
    parser = _get_parser()

    opt, unknown = parser.parse_known_args()
    train(opt)


if __name__ == "__main__":
    main()

onmt/bin/translate.py
ADDED
@@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from onmt.utils.logging import init_logger
from onmt.translate.translator import build_translator
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls
from onmt.constants import CorpusTask
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed


def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    set_random_seed(opt.seed, use_gpu(opt))

    translator = build_translator(opt, logger=logger, report_score=True)

    transforms_cls = get_transforms_cls(opt._all_transform)

    infer_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        translator.vocabs,
        task=CorpusTask.INFER,
        copy=translator.copy_attn,
    )

    infer_iter = IterOnDevice(infer_iter, opt.gpu)

    _, _ = translator._translate(
        infer_iter,
        transform=infer_iter.transform,
        attn_debug=opt.attn_debug,
        align_debug=opt.align_debug,
    )


def _get_parser():
    parser = ArgumentParser(description="translate.py")

    opts.config_opts(parser)
    opts.translate_opts(parser, dynamic=True)
    return parser


def main():
    parser = _get_parser()
    opt = parser.parse_args()
    translate(opt)


if __name__ == "__main__":
    main()

onmt/constants.py
ADDED
@@ -0,0 +1,41 @@
"""Define constant values used across the project."""


class DefaultTokens(object):
    PAD = "<blank>"
    BOS = "<s>"
    EOS = "</s>"
    UNK = "<unk>"
    MASK = "<mask>"
    VOCAB_PAD = "averyunlikelytoken"
    SENT_FULL_STOPS = [".", "?", "!"]
    PHRASE_TABLE_SEPARATOR = "|||"
    ALIGNMENT_SEPARATOR = " ||| "
    SEP = "⦅newline⦆"
    MASK_BEFORE = "⦅_mask_before_⦆"


class CorpusName(object):
    VALID = "valid"
    TRAIN = "train"
    SAMPLE = "sample"
    INFER = "infer"


class CorpusTask(object):
    TRAIN = "train"
    VALID = "valid"
    INFER = "infer"


class SubwordMarker(object):
    SPACER = "▁"
    JOINER = "■"
    BEGIN_UPPERCASE = "⦅mrk_begin_case_region_U⦆"
    END_UPPERCASE = "⦅mrk_end_case_region_U⦆"
    BEGIN_CASED = "⦅mrk_case_modifier_C⦆"


class ModelTask(object):
    LANGUAGE_MODEL = "lm"
    SEQ2SEQ = "seq2seq"

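These constants are referenced wherever the toolkit needs special tokens; a small illustration of how a specials list is typically assembled from them (the ordering shown is illustrative, not prescribed by constants.py):

```
from onmt.constants import DefaultTokens

# Typical special-token list built from the constants above.
specials = [
    DefaultTokens.UNK,   # "<unk>"
    DefaultTokens.PAD,   # "<blank>"
    DefaultTokens.BOS,   # "<s>"
    DefaultTokens.EOS,   # "</s>"
]
print(specials)
```
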
onmt/decoders/__init__.py
ADDED
@@ -0,0 +1,63 @@
"""Module defining decoders."""
import os
import importlib
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
from onmt.decoders.cnn_decoder import CNNDecoder


str2dec = {
    "rnn": StdRNNDecoder,
    "ifrnn": InputFeedRNNDecoder,
    "cnn": CNNDecoder,
    "transformer": TransformerDecoder,
    "transformer_lm": TransformerLMDecoder,
}

__all__ = [
    "DecoderBase",
    "TransformerDecoder",
    "StdRNNDecoder",
    "CNNDecoder",
    "InputFeedRNNDecoder",
    "str2dec",
    "TransformerLMDecoder",
]


def get_decoders_cls(decoders_names):
    """Return valid decoder classes indicated in `decoders_names`."""
    decoders_cls = {}
    for name in decoders_names:
        if name not in str2dec:
            raise ValueError("%s decoder not supported!" % name)
        decoders_cls[name] = str2dec[name]
    return decoders_cls


def register_decoder(name):
    """Decoder register that can be used to add a new decoder class."""

    def register_decoder_cls(cls):
        if name in str2dec:
            raise ValueError("Cannot register duplicate decoder ({})".format(name))
        if not issubclass(cls, DecoderBase):
            raise ValueError(f"decoder ({name}: {cls.__name__}) must extend DecoderBase")
        str2dec[name] = cls
        __all__.append(cls.__name__)  # added to be complete
        return cls

    return register_decoder_cls


# Auto import python files in this directory
decoder_dir = os.path.dirname(__file__)
for file in os.listdir(decoder_dir):
    path = os.path.join(decoder_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        file_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("onmt.decoders." + file_name)

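register_decoder lets downstream code plug a new decoder type into str2dec. A minimal sketch of a hypothetical custom decoder registered under a new name (the name and class are illustrative, not part of this upload):

```
from onmt.decoders import register_decoder
from onmt.decoders.decoder import DecoderBase


@register_decoder("my_decoder")  # hypothetical decoder name
class MyDecoder(DecoderBase):
    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls()

    def forward(self, tgt, enc_out, step=None, **kwargs):
        # Placeholder forward; a real decoder would return (dec_outs, attns).
        raise NotImplementedError
```
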
More compiled cache files, added as binaries and not rendered in the diff:

- onmt/decoders/__pycache__/__init__.cpython-311.pyc (3 kB)
- onmt/decoders/__pycache__/__init__.cpython-38.pyc (1.84 kB)
- onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc (7.32 kB)
- onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc (4.02 kB)
- onmt/decoders/__pycache__/decoder.cpython-311.pyc (18.4 kB)
- onmt/decoders/__pycache__/decoder.cpython-38.pyc (11.5 kB)
- onmt/decoders/__pycache__/ensemble.cpython-311.pyc (11 kB)
- onmt/decoders/__pycache__/ensemble.cpython-38.pyc (7.13 kB)
- onmt/decoders/__pycache__/transformer.cpython-311.pyc (32.9 kB)
- onmt/decoders/__pycache__/transformer.cpython-38.pyc (20.4 kB)

onmt/decoders/cnn_decoder.py
ADDED
@@ -0,0 +1,141 @@
"""Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn

from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase

SCALE_WEIGHT = 0.5**0.5


class CNNDecoder(DecoderBase):
    """Decoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.

    Consists of residual convolutional layers, with ConvMultiStepAttention.
    """

    def __init__(
        self,
        num_layers,
        hidden_size,
        attn_type,
        copy_attn,
        cnn_kernel_width,
        dropout,
        embeddings,
        copy_attn_type,
    ):
        super(CNNDecoder, self).__init__()

        self.cnn_kernel_width = cnn_kernel_width
        self.embeddings = embeddings

        # Decoder State
        self.state = {}

        input_size = self.embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.conv_layers = nn.ModuleList(
            [
                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
                for i in range(num_layers)
            ]
        )
        self.attn_layers = nn.ModuleList(
            [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
        )

        # CNNDecoder has its own attention mechanism.
        # Set up a separate copy attention layer if needed.
        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
        if copy_attn:
            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
        else:
            self.copy_attn = None

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_hid_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            embeddings,
            opt.copy_attn_type,
        )

    def init_state(self, _, enc_out, enc_hidden):
        """Init decoder state."""
        self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
        self.state["previous_input"] = None

    def map_state(self, fn):
        self.state["src"] = fn(self.state["src"], 0)
        if self.state["previous_input"] is not None:
            self.state["previous_input"] = fn(self.state["previous_input"], 0)

    def detach_state(self):
        self.state["previous_input"] = self.state["previous_input"].detach()

    def forward(self, tgt, enc_out, step=None, **kwargs):
        """See :obj:`onmt.modules.RNNDecoderBase.forward()`"""

        if self.state["previous_input"] is not None:
            tgt = torch.cat([self.state["previous_input"], tgt], 1)

        dec_outs = []
        attns = {"std": []}
        if self.copy_attn is not None:
            attns["copy"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # batch x len x embedding_dim

        tgt_emb = emb
        # The output of CNNEncoder.
        enc_out_t = enc_out
        # The combination of output of CNNEncoder and source embeddings.
        enc_out_c = self.state["src"]

        emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
        linear_out = self.linear(emb_reshape)
        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
        x = shape_transform(x)

        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)

        pad = pad.type_as(x)
        base_target_emb = x

        for conv, attention in zip(self.conv_layers, self.attn_layers):
            new_target_input = torch.cat([pad, x], 2)
            out = conv(new_target_input)
            c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT

        dec_outs = x.squeeze(3).transpose(1, 2)

        # Process the result and update the attentions.
        if self.state["previous_input"] is not None:
            dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
            attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self.copy_attn is not None:
            attns["copy"] = attn

        # Update the state.
        self.state["previous_input"] = tgt
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns

    def update_dropout(self, dropout, attention_dropout=None):
        for layer in self.conv_layers:
            layer.dropout.p = dropout

onmt/decoders/decoder.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from onmt.modules.stacked_rnn import StackedLSTM, StackedGRU
|
5 |
+
from onmt.modules import context_gate_factory, GlobalAttention
|
6 |
+
from onmt.utils.rnn_factory import rnn_factory
|
7 |
+
|
8 |
+
|
9 |
+
class DecoderBase(nn.Module):
|
10 |
+
"""Abstract class for decoders.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
attentional (bool): The decoder returns non-empty attention.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, attentional=True):
|
17 |
+
super(DecoderBase, self).__init__()
|
18 |
+
self.attentional = attentional
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def from_opt(cls, opt, embeddings):
|
22 |
+
"""Alternate constructor.
|
23 |
+
|
24 |
+
Subclasses should override this method.
|
25 |
+
"""
|
26 |
+
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
|
30 |
+
class RNNDecoderBase(DecoderBase):
|
31 |
+
"""Base recurrent attention-based decoder class.
|
32 |
+
|
33 |
+
Specifies the interface used by different decoder types
|
34 |
+
and required by :class:`~onmt.models.NMTModel`.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
rnn_type (str):
|
38 |
+
style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
|
39 |
+
bidirectional_encoder (bool) : use with a bidirectional encoder
|
40 |
+
num_layers (int) : number of stacked layers
|
41 |
+
hidden_size (int) : hidden size of each layer
|
42 |
+
attn_type (str) : see :class:`~onmt.modules.GlobalAttention`
|
43 |
+
attn_func (str) : see :class:`~onmt.modules.GlobalAttention`
|
44 |
+
coverage_attn (str): see :class:`~onmt.modules.GlobalAttention`
|
45 |
+
context_gate (str): see :class:`~onmt.modules.ContextGate`
|
46 |
+
copy_attn (bool): setup a separate copy attention mechanism
|
47 |
+
dropout (float) : dropout value for :class:`torch.nn.Dropout`
|
48 |
+
embeddings (onmt.modules.Embeddings): embedding module to use
|
49 |
+
reuse_copy_attn (bool): reuse the attention for copying
|
50 |
+
copy_attn_type (str): The copy attention style. See
|
51 |
+
:class:`~onmt.modules.GlobalAttention`.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
rnn_type,
|
57 |
+
bidirectional_encoder,
|
58 |
+
num_layers,
|
59 |
+
hidden_size,
|
60 |
+
attn_type="general",
|
61 |
+
attn_func="softmax",
|
62 |
+
coverage_attn=False,
|
63 |
+
context_gate=None,
|
64 |
+
copy_attn=False,
|
65 |
+
dropout=0.0,
|
66 |
+
embeddings=None,
|
67 |
+
reuse_copy_attn=False,
|
68 |
+
copy_attn_type="general",
|
69 |
+
):
|
70 |
+
super(RNNDecoderBase, self).__init__(
|
71 |
+
attentional=attn_type != "none" and attn_type is not None
|
72 |
+
)
|
73 |
+
|
74 |
+
self.bidirectional_encoder = bidirectional_encoder
|
75 |
+
self.num_layers = num_layers
|
76 |
+
self.hidden_size = hidden_size
|
77 |
+
self.embeddings = embeddings
|
78 |
+
self.dropout = nn.Dropout(dropout)
|
79 |
+
|
80 |
+
# Decoder state
|
81 |
+
self.state = {}
|
82 |
+
|
83 |
+
# Build the RNN.
|
84 |
+
self.rnn = self._build_rnn(
|
85 |
+
rnn_type,
|
86 |
+
input_size=self._input_size,
|
87 |
+
hidden_size=hidden_size,
|
88 |
+
num_layers=num_layers,
|
89 |
+
dropout=dropout,
|
90 |
+
)
|
91 |
+
|
92 |
+
# Set up the context gate.
|
93 |
+
self.context_gate = None
|
94 |
+
if context_gate is not None:
|
95 |
+
self.context_gate = context_gate_factory(
|
96 |
+
context_gate, self._input_size, hidden_size, hidden_size, hidden_size
|
97 |
+
)
|
98 |
+
|
99 |
+
# Set up the standard attention.
|
100 |
+
self._coverage = coverage_attn
|
101 |
+
if not self.attentional:
|
102 |
+
if self._coverage:
|
103 |
+
raise ValueError("Cannot use coverage term with no attention.")
|
104 |
+
self.attn = None
|
105 |
+
else:
|
106 |
+
self.attn = GlobalAttention(
|
107 |
+
hidden_size,
|
108 |
+
coverage=coverage_attn,
|
109 |
+
attn_type=attn_type,
|
110 |
+
attn_func=attn_func,
|
111 |
+
)
|
112 |
+
|
113 |
+
if copy_attn and not reuse_copy_attn:
|
114 |
+
if copy_attn_type == "none" or copy_attn_type is None:
|
115 |
+
raise ValueError("Cannot use copy_attn with copy_attn_type none")
|
116 |
+
self.copy_attn = GlobalAttention(
|
117 |
+
hidden_size, attn_type=copy_attn_type, attn_func=attn_func
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
self.copy_attn = None
|
121 |
+
|
122 |
+
self._reuse_copy_attn = reuse_copy_attn and copy_attn
|
123 |
+
if self._reuse_copy_attn and not self.attentional:
|
124 |
+
raise ValueError("Cannot reuse copy attention with no attention.")
|
125 |
+
|
126 |
+
@classmethod
|
127 |
+
def from_opt(cls, opt, embeddings):
|
128 |
+
"""Alternate constructor."""
|
129 |
+
return cls(
|
130 |
+
opt.rnn_type,
|
131 |
+
opt.brnn,
|
132 |
+
opt.dec_layers,
|
133 |
+
opt.dec_hid_size,
|
134 |
+
opt.global_attention,
|
135 |
+
opt.global_attention_function,
|
136 |
+
opt.coverage_attn,
|
137 |
+
opt.context_gate,
|
138 |
+
opt.copy_attn,
|
139 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
140 |
+
embeddings,
|
141 |
+
opt.reuse_copy_attn,
|
142 |
+
opt.copy_attn_type,
|
143 |
+
)
|
144 |
+
|
145 |
+
def init_state(self, src, _, enc_final_hs):
|
146 |
+
"""Initialize decoder state with last state of the encoder."""
|
147 |
+
|
148 |
+
def _fix_enc_hidden(hidden):
|
149 |
+
# The encoder hidden is (layers*directions) x batch x dim.
|
150 |
+
# We need to convert it to layers x batch x (directions*dim).
|
151 |
+
if self.bidirectional_encoder:
|
152 |
+
hidden = torch.cat(
|
153 |
+
[hidden[0 : hidden.size(0) : 2], hidden[1 : hidden.size(0) : 2]], 2
|
154 |
+
)
|
155 |
+
return hidden
|
156 |
+
|
157 |
+
if isinstance(enc_final_hs, tuple): # LSTM
|
158 |
+
self.state["hidden"] = tuple(
|
159 |
+
_fix_enc_hidden(enc_hid) for enc_hid in enc_final_hs
|
160 |
+
)
|
161 |
+
else: # GRU
|
162 |
+
self.state["hidden"] = (_fix_enc_hidden(enc_final_hs),)
|
163 |
+
|
164 |
+
# Init the input feed.
|
165 |
+
batch_size = self.state["hidden"][0].size(1)
|
166 |
+
|
167 |
+
h_size = (batch_size, self.hidden_size)
|
168 |
+
self.state["input_feed"] = (
|
169 |
+
self.state["hidden"][0].data.new(*h_size).zero_().unsqueeze(0)
|
170 |
+
)
|
171 |
+
|
172 |
+
self.state["coverage"] = None
|
173 |
+
|
174 |
+
def map_state(self, fn):
|
175 |
+
self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
|
176 |
+
self.state["input_feed"] = fn(self.state["input_feed"], 1)
|
177 |
+
if self._coverage and self.state["coverage"] is not None:
|
178 |
+
self.state["coverage"] = fn(self.state["coverage"], 1)
|
179 |
+
|
180 |
+
def detach_state(self):
|
181 |
+
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
|
182 |
+
self.state["input_feed"] = self.state["input_feed"].detach()
|
183 |
+
if self._coverage and self.state["coverage"] is not None:
|
184 |
+
self.state["coverage"] = self.state["coverage"].detach()
|
185 |
+
|
186 |
+
def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
|
187 |
+
"""
|
188 |
+
Args:
|
189 |
+
tgt (LongTensor): sequences of padded tokens
|
190 |
+
``(batch, tgt_len, nfeats)``.
|
191 |
+
enc_out (FloatTensor): vectors from the encoder
|
192 |
+
``(batch, src_len, hidden)``.
|
193 |
+
src_len (LongTensor): the padded source lengths
|
194 |
+
``(batch,)``.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
(FloatTensor, dict[str, FloatTensor]):
|
198 |
+
|
199 |
+
* dec_outs: output from the decoder (after attn)
|
200 |
+
``(batch, tgt_len, hidden)``.
|
201 |
+
* attns: distribution over src at each tgt
|
202 |
+
``(batch, tgt_len, src_len)``.
|
203 |
+
"""
|
204 |
+
dec_state, dec_outs, attns = self._run_forward_pass(
|
205 |
+
tgt, enc_out, src_len=src_len
|
206 |
+
)
|
207 |
+
|
208 |
+
# Update the state with the result.
|
209 |
+
if not isinstance(dec_state, tuple):
|
210 |
+
dec_state = (dec_state,)
|
211 |
+
self.state["hidden"] = dec_state
|
212 |
+
|
213 |
+
# Concatenates sequence of tensors along a new dimension.
|
214 |
+
# NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list
|
215 |
+
# (in particular in case of SRU) it was not raising error in 0.3
|
216 |
+
# since stack(Variable) was allowed.
|
217 |
+
# In 0.4, SRU returns a tensor that shouldn't be stacke
|
218 |
+
if type(dec_outs) == list:
|
219 |
+
dec_outs = torch.stack(dec_outs, dim=1)
|
220 |
+
for k in attns:
|
221 |
+
if type(attns[k]) == list:
|
222 |
+
attns[k] = torch.stack(attns[k])
|
223 |
+
|
224 |
+
self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0)
|
225 |
+
self.state["coverage"] = None
|
226 |
+
if "coverage" in attns:
|
227 |
+
self.state["coverage"] = attns["coverage"][-1, :, :].unsqueeze(0)
|
228 |
+
|
229 |
+
return dec_outs, attns
|
230 |
+
|
231 |
+
def update_dropout(self, dropout, attention_dropout=None):
|
232 |
+
self.dropout.p = dropout
|
233 |
+
self.embeddings.update_dropout(dropout)
|
234 |
+
|
235 |
+
|
236 |
+
class StdRNNDecoder(RNNDecoderBase):
|
237 |
+
"""Standard fully batched RNN decoder with attention.
|
238 |
+
|
239 |
+
Faster implementation, uses CuDNN for implementation.
|
240 |
+
See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
|
241 |
+
|
242 |
+
|
243 |
+
Based around the approach from
|
244 |
+
"Neural Machine Translation By Jointly Learning To Align and Translate"
|
245 |
+
:cite:`Bahdanau2015`
|
246 |
+
|
247 |
+
|
248 |
+
Implemented without input_feeding and currently with no `coverage_attn`
|
249 |
+
or `copy_attn` support.
|
250 |
+
"""
|
251 |
+
|
252 |
+
def _run_forward_pass(self, tgt, enc_out, src_len=None):
|
253 |
+
"""
|
254 |
+
Private helper for running the specific RNN forward pass.
|
255 |
+
Must be overriden by all subclasses.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
tgt (LongTensor): a sequence of input tokens tensors
|
259 |
+
``(batch, tgt_len, nfeats)``.
|
260 |
+
enc_out (FloatTensor): output(tensor sequence) from the
|
261 |
+
encoder RNN of size ``(batch, src_len, hidden_size)``.
|
262 |
+
src_len (LongTensor): the source enc_out lengths.
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
(Tensor, List[FloatTensor], Dict[str, List[FloatTensor]):
|
266 |
+
|
267 |
+
* dec_state: final hidden state from the decoder.
|
268 |
+
* dec_outs: an array of output of every time
|
269 |
+
step from the decoder.
|
270 |
+
* attns: a dictionary of different
|
271 |
+
type of attention Tensor array of every time
|
272 |
+
step from the decoder.
|
273 |
+
"""
|
274 |
+
|
275 |
+
assert self.copy_attn is None # TODO, no support yet.
|
276 |
+
assert not self._coverage # TODO, no support yet.
|
277 |
+
|
278 |
+
attns = {}
|
279 |
+
emb = self.embeddings(tgt)
|
280 |
+
|
281 |
+
if isinstance(self.rnn, nn.GRU):
|
282 |
+
rnn_out, dec_state = self.rnn(emb, self.state["hidden"][0])
|
283 |
+
else:
|
284 |
+
rnn_out, dec_state = self.rnn(emb, self.state["hidden"])
|
285 |
+
|
286 |
+
tgt_batch, tgt_len, _ = tgt.size()
|
287 |
+
|
288 |
+
# Calculate the attention.
|
289 |
+
if not self.attentional:
|
290 |
+
dec_outs = rnn_out
|
291 |
+
else:
|
292 |
+
dec_outs, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
|
293 |
+
attns["std"] = p_attn
|
294 |
+
|
295 |
+
# Calculate the context gate.
|
296 |
+
if self.context_gate is not None:
|
297 |
+
dec_outs = self.context_gate(
|
298 |
+
emb.view(-1, emb.size(2)),
|
299 |
+
rnn_out.view(-1, rnn_out.size(2)),
|
300 |
+
dec_outs.view(-1, dec_outs.size(2)),
|
301 |
+
)
|
302 |
+
dec_outs = dec_outs.view(tgt_batch, tgt_len, self.hidden_size)
|
303 |
+
|
304 |
+
dec_outs = self.dropout(dec_outs)
|
305 |
+
|
306 |
+
return dec_state, dec_outs, attns
|
307 |
+
|
308 |
+
def _build_rnn(self, rnn_type, **kwargs):
|
309 |
+
rnn, _ = rnn_factory(rnn_type, **kwargs)
|
310 |
+
return rnn
|
311 |
+
|
312 |
+
@property
|
313 |
+
def _input_size(self):
|
314 |
+
return self.embeddings.embedding_size
|
315 |
+
|
316 |
+
|
317 |
+
class InputFeedRNNDecoder(RNNDecoderBase):
|
318 |
+
"""Input feeding based decoder.
|
319 |
+
|
320 |
+
See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
|
321 |
+
|
322 |
+
Based around the input feeding approach from
|
323 |
+
"Effective Approaches to Attention-based Neural Machine Translation"
|
324 |
+
:cite:`Luong2015`
|
325 |
+
|
326 |
+
"""
|
327 |
+
|
328 |
+
def _run_forward_pass(self, tgt, enc_out, src_len=None):
|
329 |
+
"""
|
330 |
+
See StdRNNDecoder._run_forward_pass() for description
|
331 |
+
of arguments and return values.
|
332 |
+
"""
|
333 |
+
# Additional args check.
|
334 |
+
input_feed = self.state["input_feed"].squeeze(0)
|
335 |
+
|
336 |
+
dec_outs = []
|
337 |
+
attns = {}
|
338 |
+
if self.attn is not None:
|
339 |
+
attns["std"] = []
|
340 |
+
if self.copy_attn is not None or self._reuse_copy_attn:
|
341 |
+
attns["copy"] = []
|
342 |
+
if self._coverage:
|
343 |
+
attns["coverage"] = []
|
344 |
+
|
345 |
+
emb = self.embeddings(tgt)
|
346 |
+
assert emb.dim() == 3 # batch x len x embedding_dim
|
347 |
+
|
348 |
+
dec_state = self.state["hidden"]
|
349 |
+
|
350 |
+
coverage = (
|
351 |
+
self.state["coverage"].squeeze(0)
|
352 |
+
if self.state["coverage"] is not None
|
353 |
+
else None
|
354 |
+
)
|
355 |
+
|
356 |
+
# Input feed concatenates hidden state with
|
357 |
+
# input at every time step.
|
358 |
+
for emb_t in emb.split(1, dim=1):
|
359 |
+
dec_in = torch.cat([emb_t.squeeze(1), input_feed], 1)
|
360 |
+
rnn_out, dec_state = self.rnn(dec_in, dec_state)
|
361 |
+
if self.attentional:
|
362 |
+
dec_out, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
|
363 |
+
attns["std"].append(p_attn)
|
364 |
+
else:
|
365 |
+
dec_out = rnn_out
|
366 |
+
if self.context_gate is not None:
|
367 |
+
# TODO: context gate should be employed
|
368 |
+
# instead of second RNN transform.
|
369 |
+
dec_out = self.context_gate(dec_in, rnn_out, dec_out)
|
370 |
+
dec_out = self.dropout(dec_out)
|
371 |
+
input_feed = dec_out
|
372 |
+
|
373 |
+
dec_outs += [dec_out]
|
374 |
+
|
375 |
+
# Update the coverage attention.
|
376 |
+
# attns["coverage"] is actually c^(t+1) of See et al(2017)
|
377 |
+
# 1-index shifted
|
378 |
+
if self._coverage:
|
379 |
+
coverage = p_attn if coverage is None else p_attn + coverage
|
380 |
+
attns["coverage"] += [coverage]
|
381 |
+
|
382 |
+
if self.copy_attn is not None:
|
383 |
+
_, copy_attn = self.copy_attn(dec_out, enc_out)
|
384 |
+
attns["copy"] += [copy_attn]
|
385 |
+
elif self._reuse_copy_attn:
|
386 |
+
attns["copy"] = attns["std"]
|
387 |
+
|
388 |
+
return dec_state, dec_outs, attns
|
389 |
+
|
390 |
+
def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout):
|
391 |
+
assert rnn_type != "SRU", (
|
392 |
+
"SRU doesn't support input feed! " "Please set -input_feed 0!"
|
393 |
+
)
|
394 |
+
stacked_cell = StackedLSTM if rnn_type == "LSTM" else StackedGRU
|
395 |
+
return stacked_cell(num_layers, input_size, hidden_size, dropout)
|
396 |
+
|
397 |
+
@property
|
398 |
+
def _input_size(self):
|
399 |
+
"""Using input feed by concatenating input with attention vectors."""
|
400 |
+
return self.embeddings.embedding_size + self.hidden_size
|
401 |
+
|
402 |
+
def update_dropout(self, dropout, attention_dropout=None):
|
403 |
+
self.dropout.p = dropout
|
404 |
+
self.rnn.dropout.p = dropout
|
405 |
+
self.embeddings.update_dropout(dropout)
|
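
The input-feed loop above concatenates the previous attentional output with the current target embedding at every step, which is why `_input_size` returns `embedding_size + hidden_size`. Below is a minimal, self-contained sketch of that recurrence; the plain `nn.LSTMCell` and the linear stand-in for attention are illustrative assumptions, not the `StackedLSTM`/`GlobalAttention` modules used in this repository.

    # Hedged sketch of the input-feeding recurrence (not the OpenNMT-py API).
    import torch
    import torch.nn as nn

    emb_dim, hidden, batch, tgt_len = 8, 16, 2, 5
    cell = nn.LSTMCell(emb_dim + hidden, hidden)   # input feed => emb_dim + hidden inputs
    attn_proj = nn.Linear(hidden, hidden)          # stand-in for the attention module

    emb = torch.randn(batch, tgt_len, emb_dim)
    state = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))
    input_feed = torch.zeros(batch, hidden)        # previous attentional output

    dec_outs = []
    for emb_t in emb.split(1, dim=1):
        dec_in = torch.cat([emb_t.squeeze(1), input_feed], dim=1)
        h, c = cell(dec_in, state)
        state = (h, c)
        dec_out = torch.tanh(attn_proj(h))         # "attentional" output for this step
        input_feed = dec_out                       # fed back at the next time step
        dec_outs.append(dec_out)

    print(torch.stack(dec_outs, dim=1).shape)      # (batch, tgt_len, hidden)
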
onmt/decoders/ensemble.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
"""Ensemble decoding.
|
2 |
+
|
3 |
+
Decodes using multiple models simultaneously,
|
4 |
+
combining their prediction distributions by averaging.
|
5 |
+
All models in the ensemble must share a target vocabulary.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from onmt.encoders.encoder import EncoderBase
|
12 |
+
from onmt.decoders.decoder import DecoderBase
|
13 |
+
from onmt.models import NMTModel
|
14 |
+
import onmt.model_builder
|
15 |
+
|
16 |
+
|
17 |
+
class EnsembleDecoderOutput(object):
|
18 |
+
"""Wrapper around multiple decoder final hidden states."""
|
19 |
+
|
20 |
+
def __init__(self, model_dec_outs):
|
21 |
+
self.model_dec_outs = tuple(model_dec_outs)
|
22 |
+
|
23 |
+
def squeeze(self, dim=None):
|
24 |
+
"""Delegate squeeze to avoid modifying
|
25 |
+
:func:`onmt.translate.translator.Translator.translate_batch()`
|
26 |
+
"""
|
27 |
+
return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])
|
28 |
+
|
29 |
+
def __getitem__(self, index):
|
30 |
+
return self.model_dec_outs[index]
|
31 |
+
|
32 |
+
|
33 |
+
class EnsembleEncoder(EncoderBase):
|
34 |
+
"""Dummy Encoder that delegates to individual real Encoders."""
|
35 |
+
|
36 |
+
def __init__(self, model_encoders):
|
37 |
+
super(EnsembleEncoder, self).__init__()
|
38 |
+
self.model_encoders = nn.ModuleList(model_encoders)
|
39 |
+
|
40 |
+
def forward(self, src, src_len=None):
|
41 |
+
enc_out, enc_final_hs, _ = zip(
|
42 |
+
*[model_encoder(src, src_len) for model_encoder in self.model_encoders]
|
43 |
+
)
|
44 |
+
return enc_out, enc_final_hs, src_len
|
45 |
+
|
46 |
+
|
47 |
+
class EnsembleDecoder(DecoderBase):
|
48 |
+
"""Dummy Decoder that delegates to individual real Decoders."""
|
49 |
+
|
50 |
+
def __init__(self, model_decoders):
|
51 |
+
model_decoders = nn.ModuleList(model_decoders)
|
52 |
+
attentional = any([dec.attentional for dec in model_decoders])
|
53 |
+
super(EnsembleDecoder, self).__init__(attentional)
|
54 |
+
self.model_decoders = model_decoders
|
55 |
+
|
56 |
+
def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
|
57 |
+
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
|
58 |
+
# src_len is a single tensor shared between all models.
|
59 |
+
# This assumption will not hold if Translator is modified
|
60 |
+
# to calculate src_len as something other than the length
|
61 |
+
# of the input.
|
62 |
+
dec_outs, attns = zip(
|
63 |
+
*[
|
64 |
+
model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
|
65 |
+
for i, model_decoder in enumerate(self.model_decoders)
|
66 |
+
]
|
67 |
+
)
|
68 |
+
mean_attns = self.combine_attns(attns)
|
69 |
+
return EnsembleDecoderOutput(dec_outs), mean_attns
|
70 |
+
|
71 |
+
def combine_attns(self, attns):
|
72 |
+
result = {}
|
73 |
+
for key in attns[0].keys():
|
74 |
+
result[key] = torch.stack(
|
75 |
+
[attn[key] for attn in attns if attn[key] is not None]
|
76 |
+
).mean(0)
|
77 |
+
return result
|
78 |
+
|
79 |
+
def init_state(self, src, enc_out, enc_hidden):
|
80 |
+
"""See :obj:`RNNDecoderBase.init_state()`"""
|
81 |
+
for i, model_decoder in enumerate(self.model_decoders):
|
82 |
+
model_decoder.init_state(src, enc_out[i], enc_hidden[i])
|
83 |
+
|
84 |
+
def map_state(self, fn):
|
85 |
+
for model_decoder in self.model_decoders:
|
86 |
+
model_decoder.map_state(fn)
|
87 |
+
|
88 |
+
|
89 |
+
class EnsembleGenerator(nn.Module):
|
90 |
+
"""
|
91 |
+
Dummy Generator that delegates to individual real Generators,
|
92 |
+
and then averages the resulting target distributions.
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, model_generators, raw_probs=False):
|
96 |
+
super(EnsembleGenerator, self).__init__()
|
97 |
+
self.model_generators = nn.ModuleList(model_generators)
|
98 |
+
self._raw_probs = raw_probs
|
99 |
+
|
100 |
+
def forward(self, hidden, attn=None, src_map=None):
|
101 |
+
"""
|
102 |
+
Compute a distribution over the target dictionary
|
103 |
+
by averaging distributions from models in the ensemble.
|
104 |
+
All models in the ensemble must share a target vocabulary.
|
105 |
+
"""
|
106 |
+
distributions = torch.stack(
|
107 |
+
[
|
108 |
+
mg(h) if attn is None else mg(h, attn, src_map)
|
109 |
+
for h, mg in zip(hidden, self.model_generators)
|
110 |
+
]
|
111 |
+
)
|
112 |
+
if self._raw_probs:
|
113 |
+
return torch.log(torch.exp(distributions).mean(0))
|
114 |
+
else:
|
115 |
+
return distributions.mean(0)
|
116 |
+
|
117 |
+
|
118 |
+
class EnsembleModel(NMTModel):
|
119 |
+
"""Dummy NMTModel wrapping individual real NMTModels."""
|
120 |
+
|
121 |
+
def __init__(self, models, raw_probs=False):
|
122 |
+
encoder = EnsembleEncoder(model.encoder for model in models)
|
123 |
+
decoder = EnsembleDecoder(model.decoder for model in models)
|
124 |
+
super(EnsembleModel, self).__init__(encoder, decoder)
|
125 |
+
self.generator = EnsembleGenerator(
|
126 |
+
[model.generator for model in models], raw_probs
|
127 |
+
)
|
128 |
+
self.models = nn.ModuleList(models)
|
129 |
+
|
130 |
+
|
131 |
+
def load_test_model(opt, device_id=0):
|
132 |
+
"""Read in multiple models for ensemble."""
|
133 |
+
shared_vocabs = None
|
134 |
+
shared_model_opt = None
|
135 |
+
models = []
|
136 |
+
for model_path in opt.models:
|
137 |
+
vocabs, model, model_opt = onmt.model_builder.load_test_model(
|
138 |
+
opt, device_id, model_path=model_path
|
139 |
+
)
|
140 |
+
if shared_vocabs is None:
|
141 |
+
shared_vocabs = vocabs
|
142 |
+
else:
|
143 |
+
assert (
|
144 |
+
shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
|
145 |
+
), "Ensemble models must use the same vocabs "
|
146 |
+
models.append(model)
|
147 |
+
if shared_model_opt is None:
|
148 |
+
shared_model_opt = model_opt
|
149 |
+
ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
|
150 |
+
return shared_vocabs, ensemble_model, shared_model_opt
|
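
EnsembleGenerator above supports two ways of combining the per-model generator outputs: with `avg_raw_probs` it averages in probability space and re-applies the log, otherwise it averages the log-probabilities directly. A small hedged sketch of the difference; the random log-softmax tensors stand in for real generator outputs over a shared target vocabulary.

    import torch

    # three models, batch of 4, vocab of 10 (illustrative sizes)
    log_probs = torch.log_softmax(torch.randn(3, 4, 10), dim=-1)

    avg_raw = torch.log(torch.exp(log_probs).mean(0))   # avg_raw_probs=True
    avg_log = log_probs.mean(0)                         # default: average log-probs

    print(avg_raw.shape, avg_log.shape)                 # both (4, 10)
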
onmt/decoders/transformer.py
ADDED
@@ -0,0 +1,835 @@
1 |
+
"""
|
2 |
+
Implementation of "Attention is All You Need" and of
|
3 |
+
subsequent transformer based architectures
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from onmt.decoders.decoder import DecoderBase
|
9 |
+
from onmt.modules import MultiHeadedAttention, AverageAttention
|
10 |
+
from onmt.modules.position_ffn import PositionwiseFeedForward
|
11 |
+
from onmt.modules.position_ffn import ActivationFunction
|
12 |
+
from onmt.utils.misc import sequence_mask
|
13 |
+
from onmt.modules.rmsnorm import RMSNorm
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerDecoderLayerBase(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
d_model,
|
20 |
+
heads,
|
21 |
+
d_ff,
|
22 |
+
dropout,
|
23 |
+
attention_dropout,
|
24 |
+
self_attn_type="scaled-dot",
|
25 |
+
max_relative_positions=0,
|
26 |
+
relative_positions_buckets=0,
|
27 |
+
aan_useffn=False,
|
28 |
+
full_context_alignment=False,
|
29 |
+
alignment_heads=0,
|
30 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
31 |
+
add_qkvbias=False,
|
32 |
+
num_kv=0,
|
33 |
+
add_ffnbias=True,
|
34 |
+
parallel_residual=False,
|
35 |
+
shared_layer_norm=False,
|
36 |
+
layer_norm="standard",
|
37 |
+
norm_eps=1e-6,
|
38 |
+
use_ckpting=[],
|
39 |
+
parallel_gpu=1,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
d_model (int): the dimension of keys/values/queries in
|
44 |
+
:class:`MultiHeadedAttention`, also the input size of
|
45 |
+
the first-layer of the :class:`PositionwiseFeedForward`.
|
46 |
+
heads (int): the number of heads for MultiHeadedAttention.
|
47 |
+
d_ff (int): the second-layer of the
|
48 |
+
:class:`PositionwiseFeedForward`.
|
49 |
+
dropout (float): dropout in residual, self-attn(dot) and
|
50 |
+
feed-forward
|
51 |
+
attention_dropout (float): dropout in context_attn (and
|
52 |
+
self-attn(avg))
|
53 |
+
self_attn_type (string): type of self-attention scaled-dot,
|
54 |
+
average
|
55 |
+
max_relative_positions (int):
|
56 |
+
Max distance between inputs in relative positions
|
57 |
+
representations
|
58 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
59 |
+
full_context_alignment (bool):
|
60 |
+
whether enable an extra full context decoder forward for
|
61 |
+
alignment
|
62 |
+
alignment_heads (int):
|
63 |
+
N. of cross attention heads to use for alignment guiding
|
64 |
+
pos_ffn_activation_fn (ActivationFunction):
|
65 |
+
activation function choice for PositionwiseFeedForward layer
|
66 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
67 |
+
layer_norm (string): type of layer normalization standard/rms
|
68 |
+
norm_eps (float): layer norm epsilon
|
69 |
+
|
70 |
+
"""
|
71 |
+
super(TransformerDecoderLayerBase, self).__init__()
|
72 |
+
|
73 |
+
self.self_attn_type = self_attn_type
|
74 |
+
if self_attn_type == "scaled-dot":
|
75 |
+
self.self_attn = MultiHeadedAttention(
|
76 |
+
heads,
|
77 |
+
d_model,
|
78 |
+
dropout=attention_dropout,
|
79 |
+
max_relative_positions=max_relative_positions,
|
80 |
+
relative_positions_buckets=relative_positions_buckets,
|
81 |
+
attn_type="self",
|
82 |
+
add_qkvbias=add_qkvbias,
|
83 |
+
num_kv=num_kv,
|
84 |
+
use_ckpting=use_ckpting,
|
85 |
+
parallel_gpu=parallel_gpu,
|
86 |
+
)
|
87 |
+
elif self_attn_type == "average":
|
88 |
+
self.self_attn = AverageAttention(
|
89 |
+
d_model, dropout=attention_dropout, aan_useffn=aan_useffn
|
90 |
+
)
|
91 |
+
|
92 |
+
self.feed_forward = PositionwiseFeedForward(
|
93 |
+
d_model,
|
94 |
+
d_ff,
|
95 |
+
dropout,
|
96 |
+
pos_ffn_activation_fn,
|
97 |
+
add_ffnbias,
|
98 |
+
parallel_residual,
|
99 |
+
layer_norm,
|
100 |
+
norm_eps,
|
101 |
+
use_ckpting=use_ckpting,
|
102 |
+
parallel_gpu=parallel_gpu,
|
103 |
+
)
|
104 |
+
self.parallel_residual = parallel_residual
|
105 |
+
self.shared_layer_norm = shared_layer_norm
|
106 |
+
if layer_norm == "standard":
|
107 |
+
self.layer_norm_1 = nn.LayerNorm(d_model, eps=norm_eps)
|
108 |
+
if parallel_residual and not shared_layer_norm:
|
109 |
+
self.layer_norm_res = nn.LayerNorm(d_model, eps=norm_eps)
|
110 |
+
elif layer_norm == "rms":
|
111 |
+
self.layer_norm_1 = RMSNorm(d_model, eps=norm_eps)
|
112 |
+
if parallel_residual and not shared_layer_norm:
|
113 |
+
self.layer_norm_res = RMSNorm(d_model, eps=norm_eps)
|
114 |
+
else:
|
115 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
116 |
+
|
117 |
+
self.dropout = nn.Dropout(dropout)
|
118 |
+
self.full_context_alignment = full_context_alignment
|
119 |
+
self.alignment_heads = alignment_heads
|
120 |
+
|
121 |
+
def forward(self, *args, **kwargs):
|
122 |
+
"""Extend `_forward` for (possibly) multiple decoder pass:
|
123 |
+
Always a default (future masked) decoder forward pass,
|
124 |
+
Possibly a second future-aware decoder pass to jointly learn
|
125 |
+
full context alignment, :cite:`garg2019jointly`.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
* All arguments of _forward, of which
|
129 |
+
with_align (bool): needed to compute attn_align
|
130 |
+
return_attn (bool): to force MHA to return attns
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
(FloatTensor, FloatTensor, FloatTensor or None):
|
134 |
+
|
135 |
+
* layer_out ``(batch_size, T, model_dim)``
|
136 |
+
* top_attn ``(batch_size, T, src_len)``
|
137 |
+
* attn_align ``(batch_size, T, src_len)`` or None
|
138 |
+
"""
|
139 |
+
with_align = kwargs.pop("with_align", False)
|
140 |
+
layer_out, attns = self._forward(*args, **kwargs)
|
141 |
+
top_attn = None if attns is None else attns[:, 0, :, :].contiguous()
|
142 |
+
attn_align = None
|
143 |
+
if with_align:
|
144 |
+
if self.full_context_alignment:
|
145 |
+
# return _, (B, Q_len, K_len)
|
146 |
+
_, attns = self._forward(*args, **kwargs, future=True)
|
147 |
+
|
148 |
+
if self.alignment_heads > 0:
|
149 |
+
attns = attns[:, : self.alignment_heads, :, :].contiguous()
|
150 |
+
# layer average attention across heads, get ``(B, Q, K)``
|
151 |
+
# Case 1: no full_context, no align heads -> layer avg baseline
|
152 |
+
# Case 2: no full_context, 1 align heads -> guided align
|
153 |
+
# Case 3: full_context, 1 align heads -> full cte guided align
|
154 |
+
attn_align = attns.mean(dim=1)
|
155 |
+
return layer_out, top_attn, attn_align
|
156 |
+
|
157 |
+
def update_dropout(self, dropout, attention_dropout):
|
158 |
+
self.self_attn.update_dropout(attention_dropout)
|
159 |
+
self.feed_forward.update_dropout(dropout)
|
160 |
+
self.dropout.p = dropout
|
161 |
+
|
162 |
+
def _forward(self, *args, **kwargs):
|
163 |
+
raise NotImplementedError
|
164 |
+
|
165 |
+
def _compute_dec_mask(self, tgt_pad_mask, future):
|
166 |
+
tgt_len = tgt_pad_mask.size(-1)
|
167 |
+
if not future: # apply future_mask, result mask in (B, T, T)
|
168 |
+
future_mask = torch.ones(
|
169 |
+
[tgt_len, tgt_len],
|
170 |
+
device=tgt_pad_mask.device,
|
171 |
+
dtype=torch.uint8,
|
172 |
+
)
|
173 |
+
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
|
174 |
+
# BoolTensor was introduced in pytorch 1.2
|
175 |
+
try:
|
176 |
+
future_mask = future_mask.bool()
|
177 |
+
except AttributeError:
|
178 |
+
pass
|
179 |
+
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
|
180 |
+
else: # only mask padding, result mask in (B, 1, T)
|
181 |
+
dec_mask = tgt_pad_mask
|
182 |
+
return dec_mask
|
183 |
+
|
184 |
+
def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
|
185 |
+
if self.self_attn_type == "scaled-dot":
|
186 |
+
return self.self_attn(
|
187 |
+
norm_layer_in,
|
188 |
+
norm_layer_in,
|
189 |
+
norm_layer_in,
|
190 |
+
mask=dec_mask,
|
191 |
+
step=step,
|
192 |
+
return_attn=return_attn,
|
193 |
+
)
|
194 |
+
elif self.self_attn_type == "average":
|
195 |
+
return self.self_attn(norm_layer_in, mask=dec_mask, step=step)
|
196 |
+
else:
|
197 |
+
raise ValueError(f"self attention {type(self.self_attn)} not supported")
|
198 |
+
|
199 |
+
|
200 |
+
class TransformerDecoderLayer(TransformerDecoderLayerBase):
|
201 |
+
"""Transformer Decoder layer block in Pre-Norm style.
|
202 |
+
Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm style,
|
203 |
+
providing better convergence speed and performance. This is also the actual
|
204 |
+
implementation in tensor2tensor and is also available in fairseq.
|
205 |
+
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
|
206 |
+
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
d_model,
|
212 |
+
heads,
|
213 |
+
d_ff,
|
214 |
+
dropout,
|
215 |
+
attention_dropout,
|
216 |
+
self_attn_type="scaled-dot",
|
217 |
+
max_relative_positions=0,
|
218 |
+
relative_positions_buckets=0,
|
219 |
+
aan_useffn=False,
|
220 |
+
full_context_alignment=False,
|
221 |
+
alignment_heads=0,
|
222 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
223 |
+
add_qkvbias=False,
|
224 |
+
num_kv=0,
|
225 |
+
add_ffnbias=True,
|
226 |
+
parallel_residual=False,
|
227 |
+
shared_layer_norm=False,
|
228 |
+
layer_norm="standard",
|
229 |
+
norm_eps=1e-6,
|
230 |
+
use_ckpting=[],
|
231 |
+
parallel_gpu=1,
|
232 |
+
):
|
233 |
+
"""
|
234 |
+
Args:
|
235 |
+
See TransformerDecoderLayerBase
|
236 |
+
"""
|
237 |
+
super(TransformerDecoderLayer, self).__init__(
|
238 |
+
d_model,
|
239 |
+
heads,
|
240 |
+
d_ff,
|
241 |
+
dropout,
|
242 |
+
attention_dropout,
|
243 |
+
self_attn_type,
|
244 |
+
max_relative_positions,
|
245 |
+
relative_positions_buckets,
|
246 |
+
aan_useffn,
|
247 |
+
full_context_alignment,
|
248 |
+
alignment_heads,
|
249 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
250 |
+
add_qkvbias=add_qkvbias,
|
251 |
+
num_kv=num_kv,
|
252 |
+
add_ffnbias=add_ffnbias,
|
253 |
+
parallel_residual=parallel_residual,
|
254 |
+
shared_layer_norm=shared_layer_norm,
|
255 |
+
layer_norm=layer_norm,
|
256 |
+
norm_eps=norm_eps,
|
257 |
+
use_ckpting=use_ckpting,
|
258 |
+
parallel_gpu=parallel_gpu,
|
259 |
+
)
|
260 |
+
self.context_attn = MultiHeadedAttention(
|
261 |
+
heads,
|
262 |
+
d_model,
|
263 |
+
dropout=attention_dropout,
|
264 |
+
attn_type="context",
|
265 |
+
add_qkvbias=add_qkvbias,
|
266 |
+
num_kv=num_kv,
|
267 |
+
use_ckpting=use_ckpting,
|
268 |
+
parallel_gpu=parallel_gpu,
|
269 |
+
)
|
270 |
+
if layer_norm == "standard":
|
271 |
+
self.layer_norm_2 = nn.LayerNorm(d_model, eps=norm_eps)
|
272 |
+
elif layer_norm == "rms":
|
273 |
+
self.layer_norm_2 = RMSNorm(d_model, eps=norm_eps)
|
274 |
+
else:
|
275 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
276 |
+
|
277 |
+
def update_dropout(self, dropout, attention_dropout):
|
278 |
+
super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout)
|
279 |
+
self.context_attn.update_dropout(attention_dropout)
|
280 |
+
|
281 |
+
def _forward(
|
282 |
+
self,
|
283 |
+
layer_in,
|
284 |
+
enc_out,
|
285 |
+
src_pad_mask,
|
286 |
+
tgt_pad_mask,
|
287 |
+
step=None,
|
288 |
+
future=False,
|
289 |
+
return_attn=False,
|
290 |
+
):
|
291 |
+
"""A naive forward pass for transformer decoder.
|
292 |
+
|
293 |
+
# T: could be 1 in the case of stepwise decoding or tgt_len
|
294 |
+
|
295 |
+
Args:
|
296 |
+
layer_in (FloatTensor): ``(batch_size, T, model_dim)``
|
297 |
+
enc_out (FloatTensor): ``(batch_size, src_len, model_dim)``
|
298 |
+
src_pad_mask (bool): ``(batch_size, 1, src_len)``
|
299 |
+
tgt_pad_mask (bool): ``(batch_size, 1, T)``
|
300 |
+
step (int or None): stepwise decoding counter
|
301 |
+
future (bool): If set True, do not apply future_mask.
|
302 |
+
return_attn (bool) : if set True requires attns output
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
(FloatTensor, FloatTensor):
|
306 |
+
|
307 |
+
* layer_out ``(batch_size, T, model_dim)``
|
308 |
+
* attns ``(batch_size, head, T, src_len)``
|
309 |
+
|
310 |
+
"""
|
311 |
+
dec_mask = None
|
312 |
+
src_pad_mask = src_pad_mask.unsqueeze(1) # [B,1,1,slen]
|
313 |
+
|
314 |
+
if layer_in.size(1) > 1:
|
315 |
+
# masking is necessary when sequence length is greater than one
|
316 |
+
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
|
317 |
+
dec_mask = dec_mask.unsqueeze(1)
|
318 |
+
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
|
319 |
+
src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1)
|
320 |
+
# masks are now (batch x 1 x tlen x slen or tlen)
|
321 |
+
# 1 = heads to be expanded in MHA
|
322 |
+
|
323 |
+
norm_layer_in = self.layer_norm_1(layer_in)
|
324 |
+
|
325 |
+
self_attn, _ = self._forward_self_attn(norm_layer_in, dec_mask, step)
|
326 |
+
|
327 |
+
if self.parallel_residual:
|
328 |
+
ctx_attn, attns = self.context_attn(
|
329 |
+
enc_out,
|
330 |
+
enc_out,
|
331 |
+
norm_layer_in,
|
332 |
+
mask=src_pad_mask,
|
333 |
+
return_attn=return_attn,
|
334 |
+
)
|
335 |
+
# feed_forward adds its own residual (norm_layer_in), so subtract it and
|
336 |
+
layer_out = (
|
337 |
+
self.feed_forward(norm_layer_in)
|
338 |
+
- norm_layer_in
|
339 |
+
+ layer_in
|
340 |
+
+ self.dropout(self_attn)
|
341 |
+
+ ctx_attn
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
query = self.dropout(self_attn) + layer_in
|
345 |
+
norm_query = self.layer_norm_2(query)
|
346 |
+
ctx_attn, attns = self.context_attn(
|
347 |
+
enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
|
348 |
+
)
|
349 |
+
layer_out = self.feed_forward(self.dropout(ctx_attn) + query)
|
350 |
+
|
351 |
+
return layer_out, attns
|
352 |
+
|
353 |
+
|
354 |
+
class TransformerDecoderBase(DecoderBase):
|
355 |
+
def __init__(
|
356 |
+
self, d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
357 |
+
):
|
358 |
+
super(TransformerDecoderBase, self).__init__()
|
359 |
+
|
360 |
+
self.embeddings = embeddings
|
361 |
+
|
362 |
+
# Decoder State
|
363 |
+
self.state = {}
|
364 |
+
|
365 |
+
# previously, there was a GlobalAttention module here for copy
|
366 |
+
# attention. But it was never actually used -- the "copy" attention
|
367 |
+
# just reuses the context attention.
|
368 |
+
self._copy = copy_attn
|
369 |
+
if layer_norm == "standard":
|
370 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
|
371 |
+
elif layer_norm == "rms":
|
372 |
+
self.layer_norm = RMSNorm(d_model, eps=norm_eps)
|
373 |
+
else:
|
374 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
375 |
+
|
376 |
+
self.alignment_layer = alignment_layer
|
377 |
+
|
378 |
+
@classmethod
|
379 |
+
def from_opt(cls, opt, embeddings):
|
380 |
+
"""Alternate constructor."""
|
381 |
+
return cls(
|
382 |
+
opt.dec_layers,
|
383 |
+
opt.dec_hid_size,
|
384 |
+
opt.heads,
|
385 |
+
opt.transformer_ff,
|
386 |
+
opt.copy_attn,
|
387 |
+
opt.self_attn_type,
|
388 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
389 |
+
opt.attention_dropout[0]
|
390 |
+
if type(opt.attention_dropout) is list
|
391 |
+
else opt.attention_dropout,
|
392 |
+
embeddings,
|
393 |
+
opt.max_relative_positions,
|
394 |
+
opt.relative_positions_buckets,
|
395 |
+
opt.aan_useffn,
|
396 |
+
opt.full_context_alignment,
|
397 |
+
opt.alignment_layer,
|
398 |
+
alignment_heads=opt.alignment_heads,
|
399 |
+
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
|
400 |
+
add_qkvbias=opt.add_qkvbias,
|
401 |
+
num_kv=opt.num_kv,
|
402 |
+
add_ffnbias=opt.add_ffnbias,
|
403 |
+
parallel_residual=opt.parallel_residual,
|
404 |
+
shared_layer_norm=opt.shared_layer_norm,
|
405 |
+
layer_norm=opt.layer_norm,
|
406 |
+
norm_eps=opt.norm_eps,
|
407 |
+
use_ckpting=opt.use_ckpting,
|
408 |
+
parallel_gpu=opt.world_size
|
409 |
+
if opt.parallel_mode == "tensor_parallel"
|
410 |
+
else 1,
|
411 |
+
)
|
412 |
+
|
413 |
+
def init_state(self, src, enc_out, enc_final_hs):
|
414 |
+
"""Initialize decoder state."""
|
415 |
+
self.state["src"] = src
|
416 |
+
|
417 |
+
def map_state(self, fn):
|
418 |
+
if self.state["src"] is not None:
|
419 |
+
self.state["src"] = fn(self.state["src"], 0)
|
420 |
+
for layer in self.transformer_layers:
|
421 |
+
if hasattr(layer, "context_attn"):
|
422 |
+
if layer.context_attn.layer_cache[1]["keys"].numel() != 0:
|
423 |
+
x = fn(layer.context_attn.layer_cache[1]["keys"], 0)
|
424 |
+
y = fn(layer.context_attn.layer_cache[1]["values"], 0)
|
425 |
+
layer.context_attn.layer_cache = True, {"keys": x, "values": y}
|
426 |
+
if isinstance(layer.self_attn, AverageAttention):
|
427 |
+
if layer.self_attn.layer_cache[1]["prev_g"].numel() != 0:
|
428 |
+
x = fn(layer.self_attn.layer_cache[1]["prev_g"], 0)
|
429 |
+
layer.self_attn.layer_cache = True, {"prev_g": x}
|
430 |
+
else:
|
431 |
+
if layer.self_attn.layer_cache[1]["keys"].numel() != 0:
|
432 |
+
x = fn(layer.self_attn.layer_cache[1]["keys"], 0)
|
433 |
+
y = fn(layer.self_attn.layer_cache[1]["values"], 0)
|
434 |
+
layer.self_attn.layer_cache = True, {"keys": x, "values": y}
|
435 |
+
|
436 |
+
def detach_state(self):
|
437 |
+
raise NotImplementedError
|
438 |
+
|
439 |
+
def forward(self, *args, **kwargs):
|
440 |
+
raise NotImplementedError
|
441 |
+
|
442 |
+
def update_dropout(self, dropout, attention_dropout):
|
443 |
+
self.embeddings.update_dropout(dropout)
|
444 |
+
for layer in self.transformer_layers:
|
445 |
+
layer.update_dropout(dropout, attention_dropout)
|
446 |
+
|
447 |
+
|
448 |
+
class TransformerDecoder(TransformerDecoderBase):
|
449 |
+
"""The Transformer decoder from "Attention is All You Need".
|
450 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
|
451 |
+
|
452 |
+
Args:
|
453 |
+
num_layers (int): number of decoder layers.
|
454 |
+
d_model (int): size of the model
|
455 |
+
heads (int): number of heads
|
456 |
+
d_ff (int): size of the inner FF layer
|
457 |
+
copy_attn (bool): if using a separate copy attention
|
458 |
+
self_attn_type (str): type of self-attention scaled-dot, average
|
459 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
460 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
461 |
+
embeddings (onmt.modules.Embeddings):
|
462 |
+
embeddings to use, should have positional encodings
|
463 |
+
max_relative_positions (int):
|
464 |
+
Max distance between inputs in relative positions representations
|
465 |
+
relative_positions_buckets (int):
|
466 |
+
Number of buckets when using relative position bias
|
467 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
468 |
+
full_context_alignment (bool):
|
469 |
+
whether to enable an extra full context decoder forward for alignment
|
470 |
+
alignment_layer (int): N° Layer to supervise with for alignment guiding
|
471 |
+
alignment_heads (int):
|
472 |
+
N. of cross attention heads to use for alignment guiding
|
473 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
474 |
+
layer_norm (string): type of layer normalization standard/rms
|
475 |
+
"""
|
476 |
+
|
477 |
+
def __init__(
|
478 |
+
self,
|
479 |
+
num_layers,
|
480 |
+
d_model,
|
481 |
+
heads,
|
482 |
+
d_ff,
|
483 |
+
copy_attn,
|
484 |
+
self_attn_type,
|
485 |
+
dropout,
|
486 |
+
attention_dropout,
|
487 |
+
embeddings,
|
488 |
+
max_relative_positions,
|
489 |
+
relative_positions_buckets,
|
490 |
+
aan_useffn,
|
491 |
+
full_context_alignment,
|
492 |
+
alignment_layer,
|
493 |
+
alignment_heads,
|
494 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
495 |
+
add_qkvbias=False,
|
496 |
+
num_kv=0,
|
497 |
+
add_ffnbias=True,
|
498 |
+
parallel_residual=False,
|
499 |
+
shared_layer_norm=False,
|
500 |
+
layer_norm="standard",
|
501 |
+
norm_eps=1e-6,
|
502 |
+
use_ckpting=[],
|
503 |
+
parallel_gpu=1,
|
504 |
+
):
|
505 |
+
super(TransformerDecoder, self).__init__(
|
506 |
+
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
507 |
+
)
|
508 |
+
|
509 |
+
self.transformer_layers = nn.ModuleList(
|
510 |
+
[
|
511 |
+
TransformerDecoderLayer(
|
512 |
+
d_model,
|
513 |
+
heads,
|
514 |
+
d_ff,
|
515 |
+
dropout,
|
516 |
+
attention_dropout,
|
517 |
+
self_attn_type=self_attn_type,
|
518 |
+
max_relative_positions=max_relative_positions,
|
519 |
+
relative_positions_buckets=relative_positions_buckets,
|
520 |
+
aan_useffn=aan_useffn,
|
521 |
+
full_context_alignment=full_context_alignment,
|
522 |
+
alignment_heads=alignment_heads,
|
523 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
524 |
+
add_qkvbias=add_qkvbias,
|
525 |
+
num_kv=num_kv,
|
526 |
+
add_ffnbias=add_ffnbias,
|
527 |
+
parallel_residual=parallel_residual,
|
528 |
+
shared_layer_norm=shared_layer_norm,
|
529 |
+
layer_norm=layer_norm,
|
530 |
+
norm_eps=norm_eps,
|
531 |
+
use_ckpting=use_ckpting,
|
532 |
+
parallel_gpu=parallel_gpu,
|
533 |
+
)
|
534 |
+
for i in range(num_layers)
|
535 |
+
]
|
536 |
+
)
|
537 |
+
|
538 |
+
def detach_state(self):
|
539 |
+
self.state["src"] = self.state["src"].detach()
|
540 |
+
|
541 |
+
def forward(self, tgt, enc_out=None, step=None, **kwargs):
|
542 |
+
"""
|
543 |
+
Decode, possibly stepwise.
|
544 |
+
When training, step is always None; when decoding, step increases.
|
545 |
+
tgt (Tensor): batch x tlen x feats
|
546 |
+
enc_out (Tensor): encoder output (batch x slen x model_dim)
|
547 |
+
"""
|
548 |
+
if enc_out is None:
|
549 |
+
enc_out = self.embeddings(tgt)
|
550 |
+
if step == 0:
|
551 |
+
self._init_cache(enc_out)
|
552 |
+
elif step is None:
|
553 |
+
for layer in self.transformer_layers:
|
554 |
+
if isinstance(layer.self_attn, AverageAttention):
|
555 |
+
layer.self_attn.layer_cache = False, {"prev_g": torch.tensor([])}
|
556 |
+
else:
|
557 |
+
layer.self_attn.layer_cache = (
|
558 |
+
False,
|
559 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
560 |
+
)
|
561 |
+
layer.context_attn.layer_cache = (
|
562 |
+
False,
|
563 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
564 |
+
)
|
565 |
+
|
566 |
+
emb = self.embeddings(tgt, step=step)
|
567 |
+
dec_out = emb
|
568 |
+
assert emb.dim() == 3 # len x batch x embedding_dim
|
569 |
+
|
570 |
+
pad_idx = self.embeddings.word_padding_idx
|
571 |
+
src_lens = kwargs["src_len"]
|
572 |
+
src_max_len = self.state["src"].shape[1]
|
573 |
+
src_pad_mask = ~sequence_mask(src_lens, src_max_len) # [B x slen]
|
574 |
+
src_pad_mask = src_pad_mask.unsqueeze(1) # [B x 1 x slen]
|
575 |
+
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
|
576 |
+
|
577 |
+
with_align = kwargs.pop("with_align", False)
|
578 |
+
return_attn = with_align or self._copy
|
579 |
+
attn_aligns = []
|
580 |
+
|
581 |
+
for layer in self.transformer_layers:
|
582 |
+
dec_out, attn, attn_align = layer(
|
583 |
+
dec_out,
|
584 |
+
enc_out,
|
585 |
+
src_pad_mask,
|
586 |
+
tgt_pad_mask,
|
587 |
+
step=step,
|
588 |
+
with_align=with_align,
|
589 |
+
return_attn=return_attn,
|
590 |
+
)
|
591 |
+
if attn_align is not None:
|
592 |
+
attn_aligns.append(attn_align)
|
593 |
+
|
594 |
+
dec_out = self.layer_norm(dec_out)
|
595 |
+
|
596 |
+
attns = {"std": attn}
|
597 |
+
if self._copy:
|
598 |
+
attns["copy"] = attn
|
599 |
+
if with_align:
|
600 |
+
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
|
601 |
+
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
|
602 |
+
|
603 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
604 |
+
return dec_out, attns
|
605 |
+
|
606 |
+
def _init_cache(self, enc_out):
|
607 |
+
batch_size = enc_out.size(0)
|
608 |
+
depth = enc_out.size(-1)
|
609 |
+
|
610 |
+
for layer in self.transformer_layers:
|
611 |
+
# first value set to True triggered by the beginning of decoding
|
612 |
+
# layer_cache becomes active in the MultiHeadedAttention fwd
|
613 |
+
layer.context_attn.layer_cache = (
|
614 |
+
True,
|
615 |
+
{
|
616 |
+
"keys": torch.tensor([], device=enc_out.device),
|
617 |
+
"values": torch.tensor([], device=enc_out.device),
|
618 |
+
},
|
619 |
+
)
|
620 |
+
if isinstance(layer.self_attn, AverageAttention):
|
621 |
+
layer.self_attn.layer_cache = True, {
|
622 |
+
"prev_g": torch.zeros(
|
623 |
+
(batch_size, 1, depth), device=enc_out.device
|
624 |
+
).to(enc_out.dtype)
|
625 |
+
}
|
626 |
+
else:
|
627 |
+
layer.self_attn.layer_cache = (
|
628 |
+
True,
|
629 |
+
{
|
630 |
+
"keys": torch.tensor([], device=enc_out.device),
|
631 |
+
"values": torch.tensor([], device=enc_out.device),
|
632 |
+
},
|
633 |
+
)
|
634 |
+
|
635 |
+
|
636 |
+
class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
|
637 |
+
"""Transformer Decoder only layer block in GPT style.
|
638 |
+
Args:
|
639 |
+
See TransformerDecoderLayerBase
|
640 |
+
"""
|
641 |
+
|
642 |
+
def _forward(
|
643 |
+
self, layer_in, tgt_pad_mask, step=None, future=False, return_attn=False
|
644 |
+
):
|
645 |
+
"""A naive forward pass for transformer decoder.
|
646 |
+
|
647 |
+
# T: could be 1 in the case of stepwise decoding or tgt_len
|
648 |
+
|
649 |
+
Args:
|
650 |
+
layer_in (FloatTensor): ``(batch_size, T, model_dim)``
|
651 |
+
tgt_pad_mask (bool): ``(batch_size, 1, T)``
|
652 |
+
layer_cache (dict or None): cached layer info when stepwise decode
|
653 |
+
step (int or None): stepwise decoding counter
|
654 |
+
future (bool): If set True, do not apply future_mask.
|
655 |
+
return_attn (bool): If set True return attn
|
656 |
+
|
657 |
+
Returns:
|
658 |
+
(FloatTensor, FloatTensor):
|
659 |
+
|
660 |
+
* layer_out ``(batch_size, T, model_dim)``
|
661 |
+
* attns ``(batch_size, head, T, T)``
|
662 |
+
|
663 |
+
"""
|
664 |
+
dec_mask = None
|
665 |
+
|
666 |
+
if layer_in.size(1) > 1:
|
667 |
+
# masking is necessary when sequence length is greater than one
|
668 |
+
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
|
669 |
+
dec_mask = dec_mask.unsqueeze(1)
|
670 |
+
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
|
671 |
+
# masks are now (batch x 1 x tlen x tlen)
|
672 |
+
# 1 = heads to be expanded in MHA
|
673 |
+
|
674 |
+
norm_layer_in = self.layer_norm_1(layer_in)
|
675 |
+
|
676 |
+
attn_output, attns = self._forward_self_attn(
|
677 |
+
norm_layer_in, dec_mask, step, return_attn=return_attn
|
678 |
+
)
|
679 |
+
|
680 |
+
if self.parallel_residual:
|
681 |
+
# feed_forward adds its own residual, so subtract it and apply the residual to the un-normed layer_in
|
682 |
+
if not self.shared_layer_norm:
|
683 |
+
norm_res_layer_in = self.layer_norm_res(layer_in)
|
684 |
+
ff_in = norm_res_layer_in
|
685 |
+
else:
|
686 |
+
ff_in = norm_layer_in
|
687 |
+
layer_out = (
|
688 |
+
self.feed_forward(ff_in) - ff_in + layer_in + self.dropout(attn_output)
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
layer_out = self.dropout(attn_output) + layer_in
|
692 |
+
layer_out = self.feed_forward(layer_out)
|
693 |
+
|
694 |
+
return layer_out, attns
|
695 |
+
|
696 |
+
|
697 |
+
class TransformerLMDecoder(TransformerDecoderBase):
|
698 |
+
"""The Transformer decoder from GPT-2
|
699 |
+
Args:
|
700 |
+
num_layers (int): number of decoder layers.
|
701 |
+
d_model (int): size of the model
|
702 |
+
heads (int): number of heads
|
703 |
+
d_ff (int): size of the inner FF layer
|
704 |
+
copy_attn (bool): if using a separate copy attention
|
705 |
+
self_attn_type (str): type of self-attention scaled-dot, average
|
706 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
707 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
708 |
+
embeddings (onmt.modules.Embeddings):
|
709 |
+
embeddings to use, should have positional encodings
|
710 |
+
max_relative_positions (int):
|
711 |
+
Max distance between inputs in relative positions representations
|
712 |
+
relative_positions_buckets (int):
|
713 |
+
Number of buckets when using Relative positions bias
|
714 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
715 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
716 |
+
"""
|
717 |
+
|
718 |
+
def __init__(
|
719 |
+
self,
|
720 |
+
num_layers,
|
721 |
+
d_model,
|
722 |
+
heads,
|
723 |
+
d_ff,
|
724 |
+
copy_attn,
|
725 |
+
self_attn_type,
|
726 |
+
dropout,
|
727 |
+
attention_dropout,
|
728 |
+
embeddings,
|
729 |
+
max_relative_positions,
|
730 |
+
relative_positions_buckets,
|
731 |
+
aan_useffn,
|
732 |
+
full_context_alignment=None,
|
733 |
+
alignment_layer=None,
|
734 |
+
alignment_heads=None,
|
735 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
736 |
+
add_qkvbias=False,
|
737 |
+
num_kv=0,
|
738 |
+
add_ffnbias=True,
|
739 |
+
parallel_residual=False,
|
740 |
+
shared_layer_norm=False,
|
741 |
+
layer_norm="standard",
|
742 |
+
norm_eps=1e-6,
|
743 |
+
use_ckpting=[],
|
744 |
+
parallel_gpu=1,
|
745 |
+
):
|
746 |
+
super(TransformerLMDecoder, self).__init__(
|
747 |
+
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
748 |
+
)
|
749 |
+
self.transformer_layers = nn.ModuleList(
|
750 |
+
[
|
751 |
+
TransformerLMDecoderLayer(
|
752 |
+
d_model,
|
753 |
+
heads,
|
754 |
+
d_ff,
|
755 |
+
dropout,
|
756 |
+
attention_dropout,
|
757 |
+
self_attn_type=self_attn_type,
|
758 |
+
max_relative_positions=max_relative_positions,
|
759 |
+
relative_positions_buckets=relative_positions_buckets,
|
760 |
+
aan_useffn=aan_useffn,
|
761 |
+
full_context_alignment=None,
|
762 |
+
alignment_heads=None,
|
763 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
764 |
+
add_qkvbias=add_qkvbias,
|
765 |
+
num_kv=num_kv,
|
766 |
+
add_ffnbias=add_ffnbias,
|
767 |
+
parallel_residual=parallel_residual,
|
768 |
+
shared_layer_norm=shared_layer_norm,
|
769 |
+
layer_norm=layer_norm,
|
770 |
+
norm_eps=norm_eps,
|
771 |
+
use_ckpting=use_ckpting,
|
772 |
+
parallel_gpu=parallel_gpu,
|
773 |
+
)
|
774 |
+
for i in range(num_layers)
|
775 |
+
]
|
776 |
+
)
|
777 |
+
|
778 |
+
def init_state(self, src=None, enc_out=None, enc_final_hs=None):
|
779 |
+
super(TransformerLMDecoder, self).init_state(None, None, None)
|
780 |
+
|
781 |
+
def detach_state(self):
|
782 |
+
pass
|
783 |
+
|
784 |
+
def forward(self, tgt, enc_out=None, step=None, **kwargs):
|
785 |
+
"""Decode, possibly stepwise."""
|
786 |
+
if step == 0:
|
787 |
+
self._init_cache(tgt)
|
788 |
+
elif step is None:
|
789 |
+
for layer in self.transformer_layers:
|
790 |
+
layer.self_attn.layer_cache = (
|
791 |
+
False,
|
792 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
793 |
+
)
|
794 |
+
|
795 |
+
dec_out = self.embeddings(tgt, step=step)
|
796 |
+
|
797 |
+
assert dec_out.dim() == 3 # batch x len x embedding_dim
|
798 |
+
|
799 |
+
pad_idx = self.embeddings.word_padding_idx
|
800 |
+
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
|
801 |
+
|
802 |
+
with_align = kwargs.pop("with_align", False)
|
803 |
+
return_attn = with_align or self._copy
|
804 |
+
assert not with_align, "TransformerLMDecoder does not support align"
|
805 |
+
|
806 |
+
for layer in self.transformer_layers:
|
807 |
+
dec_out, attn, _ = layer(
|
808 |
+
dec_out,
|
809 |
+
tgt_pad_mask,
|
810 |
+
step=step,
|
811 |
+
with_align=with_align,
|
812 |
+
return_attn=return_attn,
|
813 |
+
)
|
814 |
+
|
815 |
+
dec_out = self.layer_norm(dec_out)
|
816 |
+
|
817 |
+
attns = {"std": attn}
|
818 |
+
if self._copy:
|
819 |
+
attns["copy"] = attn
|
820 |
+
|
821 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
822 |
+
return dec_out, attns
|
823 |
+
|
824 |
+
def _init_cache(self, tgt=None):
|
825 |
+
for layer in self.transformer_layers:
|
826 |
+
if isinstance(layer.self_attn, AverageAttention):
|
827 |
+
raise NotImplementedError
|
828 |
+
else:
|
829 |
+
layer.self_attn.layer_cache = (
|
830 |
+
True,
|
831 |
+
{
|
832 |
+
"keys": torch.tensor([], device=tgt.device),
|
833 |
+
"values": torch.tensor([], device=tgt.device),
|
834 |
+
},
|
835 |
+
)
|
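
`_compute_dec_mask` above builds the causal ("future") mask with `triu_(1)` and combines it with the target padding mask, so a position is masked if it is either padding or lies in the future. A minimal sketch with a toy padding pattern; the shapes and pad positions are illustrative assumptions.

    import torch

    batch, tgt_len = 2, 4
    # True = padded position, shape (B, 1, T) as produced by tgt[:, :, 0].eq(pad_idx).unsqueeze(1)
    tgt_pad_mask = torch.tensor([[[False, False, False, True]],
                                 [[False, False, True, True]]])

    future_mask = torch.ones(tgt_len, tgt_len, dtype=torch.uint8).triu_(1)
    future_mask = future_mask.bool().view(1, tgt_len, tgt_len)   # (1, T, T)

    # equivalent to torch.gt(tgt_pad_mask + future_mask, 0) in the source
    dec_mask = tgt_pad_mask | future_mask                        # (B, T, T), True = masked
    print(dec_mask[0].int())
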
onmt/encoders/__init__.py
ADDED
@@ -0,0 +1,67 @@
1 |
+
"""Module defining encoders."""
|
2 |
+
import os
|
3 |
+
import importlib
|
4 |
+
from onmt.encoders.encoder import EncoderBase
|
5 |
+
from onmt.encoders.transformer import TransformerEncoder
|
6 |
+
from onmt.encoders.ggnn_encoder import GGNNEncoder
|
7 |
+
from onmt.encoders.rnn_encoder import RNNEncoder
|
8 |
+
from onmt.encoders.cnn_encoder import CNNEncoder
|
9 |
+
from onmt.encoders.mean_encoder import MeanEncoder
|
10 |
+
|
11 |
+
|
12 |
+
str2enc = {
|
13 |
+
"ggnn": GGNNEncoder,
|
14 |
+
"rnn": RNNEncoder,
|
15 |
+
"brnn": RNNEncoder,
|
16 |
+
"cnn": CNNEncoder,
|
17 |
+
"transformer": TransformerEncoder,
|
18 |
+
"mean": MeanEncoder,
|
19 |
+
}
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"EncoderBase",
|
23 |
+
"TransformerEncoder",
|
24 |
+
"GGNNEncoder",
|
25 |
+
"RNNEncoder",
|
26 |
+
"CNNEncoder",
|
27 |
+
"MeanEncoder",
|
28 |
+
"str2enc",
|
29 |
+
]
|
30 |
+
|
31 |
+
|
32 |
+
def get_encoders_cls(encoder_names):
|
33 |
+
"""Return valid encoder class indicated in `encoder_names`."""
|
34 |
+
encoders_cls = {}
|
35 |
+
for name in encoder_names:
|
36 |
+
if name not in str2enc:
|
37 |
+
raise ValueError("%s encoder not supported!" % name)
|
38 |
+
encoders_cls[name] = str2enc[name]
|
39 |
+
return encoders_cls
|
40 |
+
|
41 |
+
|
42 |
+
def register_encoder(name):
|
43 |
+
"""Encoder register that can be used to add new encoder class."""
|
44 |
+
|
45 |
+
def register_encoder_cls(cls):
|
46 |
+
if name in str2enc:
|
47 |
+
raise ValueError("Cannot register duplicate encoder ({})".format(name))
|
48 |
+
if not issubclass(cls, EncoderBase):
|
49 |
+
raise ValueError(f"encoder ({name}: {cls.__name_}) must extend EncoderBase")
|
50 |
+
str2enc[name] = cls
|
51 |
+
__all__.append(cls.__name__) # added to be complete
|
52 |
+
return cls
|
53 |
+
|
54 |
+
return register_encoder_cls
|
55 |
+
|
56 |
+
|
57 |
+
# Auto import python files in this directory
|
58 |
+
encoder_dir = os.path.dirname(__file__)
|
59 |
+
for file in os.listdir(encoder_dir):
|
60 |
+
path = os.path.join(encoder_dir, file)
|
61 |
+
if (
|
62 |
+
not file.startswith("_")
|
63 |
+
and not file.startswith(".")
|
64 |
+
and (file.endswith(".py") or os.path.isdir(path))
|
65 |
+
):
|
66 |
+
file_name = file[: file.find(".py")] if file.endswith(".py") else file
|
67 |
+
module = importlib.import_module("onmt.encoders." + file_name)
|
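
The `register_encoder` decorator defined above lets new encoder classes add themselves to `str2enc`, and the auto-import loop at the bottom of the file registers any module dropped into onmt/encoders/ when the package is imported. A hedged usage sketch; `MyEncoder` and its trivial forward are illustrative assumptions, not part of the repository.

    from onmt.encoders import register_encoder, str2enc
    from onmt.encoders.encoder import EncoderBase


    @register_encoder("my_encoder")
    class MyEncoder(EncoderBase):
        # hypothetical encoder that simply passes the source through
        def forward(self, src, src_len=None):
            # return (enc_out, enc_final_hs, src_len) like the other encoders
            return src, None, src_len


    assert str2enc["my_encoder"] is MyEncoder
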
onmt/encoders/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3.13 kB).
|
|