Oopstom committed
Commit c668e80 · verified · 1 Parent(s): 9c9b678

Upload 313 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. e_smiles.py +0 -0
  2. infer.sh +10 -0
  3. inference.py +5 -0
  4. onmt/__init__.py +24 -0
  5. onmt/__pycache__/__init__.cpython-311.pyc +0 -0
  6. onmt/__pycache__/__init__.cpython-37.pyc +0 -0
  7. onmt/__pycache__/__init__.cpython-38.pyc +0 -0
  8. onmt/__pycache__/constants.cpython-311.pyc +0 -0
  9. onmt/__pycache__/constants.cpython-38.pyc +0 -0
  10. onmt/__pycache__/inference_engine.cpython-38.pyc +0 -0
  11. onmt/__pycache__/model_builder.cpython-311.pyc +0 -0
  12. onmt/__pycache__/model_builder.cpython-38.pyc +0 -0
  13. onmt/__pycache__/opts.cpython-311.pyc +0 -0
  14. onmt/__pycache__/opts.cpython-38.pyc +0 -0
  15. onmt/__pycache__/train_single.cpython-38.pyc +0 -0
  16. onmt/__pycache__/trainer.cpython-38.pyc +0 -0
  17. onmt/bin/__init__.py +0 -0
  18. onmt/bin/__pycache__/__init__.cpython-311.pyc +0 -0
  19. onmt/bin/__pycache__/__init__.cpython-38.pyc +0 -0
  20. onmt/bin/__pycache__/average_models.cpython-38.pyc +0 -0
  21. onmt/bin/__pycache__/build_vocab.cpython-38.pyc +0 -0
  22. onmt/bin/__pycache__/release_model.cpython-38.pyc +0 -0
  23. onmt/bin/__pycache__/server.cpython-38.pyc +0 -0
  24. onmt/bin/__pycache__/train.cpython-38.pyc +0 -0
  25. onmt/bin/__pycache__/translate.cpython-311.pyc +0 -0
  26. onmt/bin/__pycache__/translate.cpython-38.pyc +0 -0
  27. onmt/bin/average_models.py +60 -0
  28. onmt/bin/build_vocab.py +287 -0
  29. onmt/bin/release_model.py +39 -0
  30. onmt/bin/server.py +167 -0
  31. onmt/bin/train.py +71 -0
  32. onmt/bin/translate.py +60 -0
  33. onmt/constants.py +41 -0
  34. onmt/decoders/__init__.py +63 -0
  35. onmt/decoders/__pycache__/__init__.cpython-311.pyc +0 -0
  36. onmt/decoders/__pycache__/__init__.cpython-38.pyc +0 -0
  37. onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc +0 -0
  38. onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc +0 -0
  39. onmt/decoders/__pycache__/decoder.cpython-311.pyc +0 -0
  40. onmt/decoders/__pycache__/decoder.cpython-38.pyc +0 -0
  41. onmt/decoders/__pycache__/ensemble.cpython-311.pyc +0 -0
  42. onmt/decoders/__pycache__/ensemble.cpython-38.pyc +0 -0
  43. onmt/decoders/__pycache__/transformer.cpython-311.pyc +0 -0
  44. onmt/decoders/__pycache__/transformer.cpython-38.pyc +0 -0
  45. onmt/decoders/cnn_decoder.py +141 -0
  46. onmt/decoders/decoder.py +405 -0
  47. onmt/decoders/ensemble.py +150 -0
  48. onmt/decoders/transformer.py +835 -0
  49. onmt/encoders/__init__.py +67 -0
  50. onmt/encoders/__pycache__/__init__.cpython-311.pyc +0 -0
e_smiles.py ADDED
The diff for this file is too large to render. See raw diff
 
infer.sh ADDED
@@ -0,0 +1,10 @@
+python inference.py \
+    --model trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt \
+    --src ./tmp_data/src.txt \
+    --output ./tmp_data/tgt.txt \
+    --beam_size 10 \
+    --n_best 10 \
+    --batch_size 16384 \
+    --batch_type tokens \
+    --max_length 500 \
+    --seed 0
inference.py ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/env python
+from onmt.bin.translate import main
+
+if __name__ == "__main__":
+    main()
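
For context, inference.py simply forwards to OpenNMT-py's translate entry point, which reads its options from the command line. A minimal sketch of an equivalent programmatic call, assuming the flags from infer.sh above and that the checkpoint and tmp_data files exist at those paths:

import sys
from onmt.bin.translate import main

# Mirror the flags used in infer.sh; the paths are the ones committed above.
sys.argv = [
    "inference.py",
    "--model", "trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt",
    "--src", "./tmp_data/src.txt",
    "--output", "./tmp_data/tgt.txt",
    "--beam_size", "10",
    "--n_best", "10",
    "--batch_size", "16384",
    "--batch_type", "tokens",
    "--max_length", "500",
    "--seed", "0",
]

if __name__ == "__main__":
    main()  # parses sys.argv through OpenNMT-py's ArgumentParser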
onmt/__init__.py ADDED
@@ -0,0 +1,24 @@
+""" Main entry point of the ONMT library """
+import onmt.inputters
+import onmt.encoders
+import onmt.decoders
+import onmt.models
+import onmt.utils
+import onmt.modules
+import sys
+import onmt.utils.optimizers
+
+onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer
+sys.modules["onmt.Optim"] = onmt.utils.optimizers
+
+# For Flake
+__all__ = [
+    onmt.inputters,
+    onmt.encoders,
+    onmt.decoders,
+    onmt.models,
+    onmt.utils,
+    onmt.modules,
+]
+
+__version__ = "3.4.1"
onmt/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (892 Bytes). View file
 
onmt/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (605 Bytes). View file
 
onmt/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (603 Bytes). View file
 
onmt/__pycache__/constants.cpython-311.pyc ADDED
Binary file (2.06 kB). View file
 
onmt/__pycache__/constants.cpython-38.pyc ADDED
Binary file (1.61 kB). View file
 
onmt/__pycache__/inference_engine.cpython-38.pyc ADDED
Binary file (3.22 kB). View file
 
onmt/__pycache__/model_builder.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
onmt/__pycache__/model_builder.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
onmt/__pycache__/opts.cpython-311.pyc ADDED
Binary file (58 kB). View file
 
onmt/__pycache__/opts.cpython-38.pyc ADDED
Binary file (38.4 kB). View file
 
onmt/__pycache__/train_single.cpython-38.pyc ADDED
Binary file (6.41 kB). View file
 
onmt/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (14.5 kB). View file
 
onmt/bin/__init__.py ADDED
File without changes
onmt/bin/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
onmt/bin/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (145 Bytes). View file
 
onmt/bin/__pycache__/average_models.cpython-38.pyc ADDED
Binary file (1.48 kB). View file
 
onmt/bin/__pycache__/build_vocab.cpython-38.pyc ADDED
Binary file (8.77 kB). View file
 
onmt/bin/__pycache__/release_model.cpython-38.pyc ADDED
Binary file (1.17 kB). View file
 
onmt/bin/__pycache__/server.cpython-38.pyc ADDED
Binary file (5.08 kB). View file
 
onmt/bin/__pycache__/train.cpython-38.pyc ADDED
Binary file (1.84 kB). View file
 
onmt/bin/__pycache__/translate.cpython-311.pyc ADDED
Binary file (2.89 kB). View file
 
onmt/bin/__pycache__/translate.cpython-38.pyc ADDED
Binary file (1.77 kB). View file
 
onmt/bin/average_models.py ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+import argparse
+import torch
+
+
+def average_models(model_files, fp32=False):
+    vocab = None
+    opt = None
+    avg_model = None
+    avg_generator = None
+
+    for i, model_file in enumerate(model_files):
+        m = torch.load(model_file, map_location="cpu")
+        model_weights = m["model"]
+        generator_weights = m["generator"]
+
+        if fp32:
+            for k, v in model_weights.items():
+                model_weights[k] = v.float()
+            for k, v in generator_weights.items():
+                generator_weights[k] = v.float()
+
+        if i == 0:
+            vocab, opt = m["vocab"], m["opt"]
+            avg_model = model_weights
+            avg_generator = generator_weights
+        else:
+            for k, v in avg_model.items():
+                avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
+
+            for k, v in avg_generator.items():
+                avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
+
+    final = {
+        "vocab": vocab,
+        "opt": opt,
+        "optim": None,
+        "generator": avg_generator,
+        "model": avg_model,
+    }
+    return final
+
+
+def main():
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "-models", "-m", nargs="+", required=True, help="List of models"
+    )
+    parser.add_argument("-output", "-o", required=True, help="Output file")
+    parser.add_argument(
+        "-fp32", "-f", action="store_true", help="Cast params to float32"
+    )
+    opt = parser.parse_args()
+
+    final = average_models(opt.models, opt.fp32)
+    torch.save(final, opt.output)
+
+
+if __name__ == "__main__":
+    main()
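
average_models() can also be used directly from Python rather than through the CLI. A minimal sketch, assuming two hypothetical checkpoint paths (the filenames below are placeholders, not part of this commit):

import torch
from onmt.bin.average_models import average_models

# Placeholder checkpoint names; any list of compatible OpenNMT-py checkpoints works.
checkpoints = ["model_step_10000.pt", "model_step_20000.pt"]

# Running average of the "model" and "generator" weights; vocab/opt come from the first file.
averaged = average_models(checkpoints, fp32=True)
torch.save(averaged, "model_avg.pt")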
onmt/bin/build_vocab.py ADDED
@@ -0,0 +1,287 @@
+#!/usr/bin/env python
+"""Get vocabulary coutings from transformed corpora samples."""
+import os
+import copy
+import multiprocessing as mp
+import pyonmttok
+from functools import partial
+from onmt.utils.logging import init_logger, logger
+from onmt.utils.misc import set_random_seed, check_path
+from onmt.utils.parse import ArgumentParser
+from onmt.opts import dynamic_prepare_opts
+from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
+from onmt.inputters.text_utils import process, append_features_to_text
+from onmt.transforms import make_transforms, get_transforms_cls
+from onmt.constants import CorpusName, CorpusTask
+from collections import Counter
+
+
+MAXBUCKETSIZE = 256000
+
+
+def write_files_from_queues(sample_path, queues):
+    """
+    Standalone process that reads data from
+    queues in order and write to sample files.
+    """
+    os.makedirs(sample_path, exist_ok=True)
+    for c_name in queues.keys():
+        dest_base = os.path.join(sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
+        with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
+            dest_base + ".tgt", "w", encoding="utf-8"
+        ) as f_tgt:
+            while True:
+                _next = False
+                for q in queues[c_name]:
+                    item = q.get()
+                    if item == "blank":
+                        continue
+                    if item == "break":
+                        _next = True
+                        break
+                    _, src_line, tgt_line = item
+                    f_src.write(src_line + "\n")
+                    f_tgt.write(tgt_line + "\n")
+                if _next:
+                    break
+
+
+def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
+    """Build vocab on (strided) subpart of the data."""
+    sub_counter_src = Counter()
+    sub_counter_tgt = Counter()
+    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    datasets_iterables = build_corpora_iters(
+        corpora,
+        transforms,
+        opts.data,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride,
+        offset=offset,
+    )
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            maybe_example = process(CorpusTask.TRAIN, [item])
+            if maybe_example is not None:
+                maybe_example = maybe_example[0]
+            else:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("blank")
+                continue
+            src_line, tgt_line = (
+                maybe_example["src"]["src"],
+                maybe_example["tgt"]["tgt"],
+            )
+            sub_counter_src.update(src_line.split(" "))
+            sub_counter_tgt.update(tgt_line.split(" "))
+
+            if "feats" in maybe_example["src"]:
+                src_feats_lines = maybe_example["src"]["feats"]
+                for k in range(opts.n_src_feats):
+                    sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
+            else:
+                src_feats_lines = []
+
+            if opts.dump_samples:
+                src_pretty_line = append_features_to_text(src_line, src_feats_lines)
+                build_sub_vocab.queues[c_name][offset].put(
+                    (i, src_pretty_line, tgt_line)
+                )
+            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("break")
+                break
+        if opts.dump_samples:
+            build_sub_vocab.queues[c_name][offset].put("break")
+    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
+
+
+def init_pool(queues):
+    """Add the queues as attribute of the pooled function."""
+    build_sub_vocab.queues = queues
+
+
+def build_vocab(opts, transforms, n_sample=3):
+    """Build vocabulary from data."""
+
+    if n_sample == -1:
+        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
+    elif n_sample > 0:
+        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
+    else:
+        raise ValueError(f"n_sample should > 0 or == -1, get {n_sample}.")
+
+    if opts.dump_samples:
+        logger.info(
+            "The samples on which the vocab is built will be "
+            "dumped to disk. It may slow down the process."
+        )
+    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
+    counter_src = Counter()
+    counter_tgt = Counter()
+    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+
+    queues = {
+        c_name: [
+            mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
+        ]
+        for c_name in corpora.keys()
+    }
+    sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
+    if opts.dump_samples:
+        write_process = mp.Process(
+            target=write_files_from_queues, args=(sample_path, queues), daemon=True
+        )
+        write_process.start()
+    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
+        func = partial(
+            build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
+        )
+        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
+            func, range(0, opts.num_threads)
+        ):
+            counter_src.update(sub_counter_src)
+            counter_tgt.update(sub_counter_tgt)
+            for i in range(opts.n_src_feats):
+                counter_src_feats[i].update(sub_counter_src_feats[i])
+    if opts.dump_samples:
+        write_process.join()
+    return counter_src, counter_tgt, counter_src_feats
+
+
+def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
+    def _mp_ingest(data):
+        func = partial(process, CorpusName.TRAIN)
+        chunk = len(data) // opts.num_threads
+        with mp.Pool(opts.num_threads) as pool:
+            buckets = pool.map(
+                func,
+                [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
+            )
+        for bucket in buckets:
+            for ex in bucket:
+                if ex is not None:
+                    src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
+                    learner.ingest(src_line)
+                    learner.ingest(tgt_line)
+
+    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
+    datasets_iterables = build_corpora_iters(
+        corpora,
+        transforms,
+        opts.data,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride,
+        offset=offset,
+    )
+    to_ingest = []
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            if n_sample >= 0 and i >= n_sample:
+                break
+            if len(to_ingest) >= MAXBUCKETSIZE:
+                _mp_ingest(to_ingest)
+                to_ingest = []
+            to_ingest.append(item)
+    _mp_ingest(to_ingest)
+
+
+def make_learner(tokenization_type, symbols):
+    if tokenization_type == "bpe":
+        # BPE training
+        learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
+    elif tokenization_type == "sentencepiece":
+        # SentencePiece training
+        learner = pyonmttok.SentencePieceLearner(
+            vocab_size=symbols, character_coverage=0.98
+        )
+    return learner
+
+
+def build_vocab_main(opts):
+    """Apply transforms to samples of specified data and build vocab from it.
+
+    Transforms that need vocab will be disabled in this.
+    Built vocab is saved in plain text format as following and can be pass as
+    `-src_vocab` (and `-tgt_vocab`) when training:
+    ```
+    <tok_0>\t<count_0>
+    <tok_1>\t<count_1>
+    ```
+    """
+
+    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
+    assert (
+        opts.n_sample == -1 or opts.n_sample > 1
+    ), f"Illegal argument n_sample={opts.n_sample}."
+
+    logger = init_logger()
+    set_random_seed(opts.seed, False)
+    transforms_cls = get_transforms_cls(opts._all_transform)
+
+    if opts.learn_subwords:
+        logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
+        learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
+        if opts.src_subword_model is not None:
+            tok_path = opts.src_subword_model
+        else:
+            data_dir = os.path.split(opts.save_data)[0]
+            if not os.path.exists(data_dir):
+                os.makedirs(data_dir)
+            tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
+        save_opts = copy.deepcopy(opts)
+        opts.src_subword_type = "none"
+        opts.tgt_subword_type = "none"
+        opts.src_onmttok_kwargs["joiner_annotate"] = False
+        opts.tgt_onmttok_kwargs["joiner_annotate"] = False
+        transforms = make_transforms(opts, transforms_cls, None)
+        ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
+        logger.info(f"Learning {tok_path} model, patience")
+        learner.learn(tok_path)
+        opts = save_opts
+
+    transforms = make_transforms(opts, transforms_cls, None)
+
+    logger.info(f"Counter vocab from {opts.n_sample} samples.")
+    src_counter, tgt_counter, src_feats_counter = build_vocab(
+        opts, transforms, n_sample=opts.n_sample
+    )
+
+    logger.info(f"Counters src: {len(src_counter)}")
+    logger.info(f"Counters tgt: {len(tgt_counter)}")
+    for i, feat_counter in enumerate(src_feats_counter):
+        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
+
+    def save_counter(counter, save_path):
+        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
+        with open(save_path, "w", encoding="utf8") as fo:
+            for tok, count in counter.most_common():
+                fo.write(tok + "\t" + str(count) + "\n")
+
+    if opts.share_vocab:
+        src_counter += tgt_counter
+        tgt_counter = src_counter
+        logger.info(f"Counters after share:{len(src_counter)}")
+        save_counter(src_counter, opts.src_vocab)
+    else:
+        save_counter(src_counter, opts.src_vocab)
+        save_counter(tgt_counter, opts.tgt_vocab)
+
+    for i, c in enumerate(src_feats_counter):
+        save_counter(c, f"{opts.src_vocab}_feat{i}")
+
+
+def _get_parser():
+    parser = ArgumentParser(description="build_vocab.py")
+    dynamic_prepare_opts(parser, build_vocab_only=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    opts, unknown = parser.parse_known_args()
+    build_vocab_main(opts)
+
+
+if __name__ == "__main__":
+    main()
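
The vocab files written by save_counter() above are plain text with one token and its count per line, separated by a tab and sorted by frequency. A small sketch of reading one back (the path is a placeholder):

# Read a vocab file produced by build_vocab.py (token and count separated by a tab).
def load_vocab_counts(path):
    counts = {}
    with open(path, encoding="utf8") as f:
        for line in f:
            tok, count = line.rstrip("\n").split("\t")
            counts[tok] = int(count)
    return counts

vocab = load_vocab_counts("vocab.src")  # placeholder path
print(len(vocab), "source tokens")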
onmt/bin/release_model.py ADDED
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+import argparse
+import torch
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Release an OpenNMT-py model for inference"
+    )
+    parser.add_argument("--model", "-m", help="The model path", required=True)
+    parser.add_argument("--output", "-o", help="The output path", required=True)
+    parser.add_argument(
+        "--format",
+        choices=["pytorch", "ctranslate2"],
+        default="pytorch",
+        help="The format of the released model",
+    )
+    parser.add_argument(
+        "--quantization",
+        "-q",
+        choices=["int8", "int16", "float16", "int8_float16"],
+        default=None,
+        help="Quantization type for CT2 model.",
+    )
+    opt = parser.parse_args()
+
+    model = torch.load(opt.model, map_location=torch.device("cpu"))
+    if opt.format == "pytorch":
+        model["optim"] = None
+        torch.save(model, opt.output)
+    elif opt.format == "ctranslate2":
+        import ctranslate2
+
+        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
+        converter.convert(opt.output, force=True, quantization=opt.quantization)
+
+
+if __name__ == "__main__":
+    main()
onmt/bin/server.py ADDED
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+import configargparse
+
+from flask import Flask, jsonify, request
+from waitress import serve
+from onmt.translate import TranslationServer, ServerModelError
+import logging
+from logging.handlers import RotatingFileHandler
+
+STATUS_OK = "ok"
+STATUS_ERROR = "error"
+
+
+def start(config_file, url_root="./translator", host="0.0.0.0", port=5000, debug=False):
+    def prefix_route(route_function, prefix="", mask="{0}{1}"):
+        def newroute(route, *args, **kwargs):
+            return route_function(mask.format(prefix, route), *args, **kwargs)
+
+        return newroute
+
+    if debug:
+        logger = logging.getLogger("main")
+        log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
+        file_handler = RotatingFileHandler(
+            "debug_requests.log", maxBytes=1000000, backupCount=10
+        )
+        file_handler.setFormatter(log_format)
+        logger.addHandler(file_handler)
+
+    app = Flask(__name__)
+    app.route = prefix_route(app.route, url_root)
+    translation_server = TranslationServer()
+    translation_server.start(config_file)
+
+    @app.route("/models", methods=["GET"])
+    def get_models():
+        out = translation_server.list_models()
+        return jsonify(out)
+
+    @app.route("/health", methods=["GET"])
+    def health():
+        out = {}
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    @app.route("/clone_model/<int:model_id>", methods=["POST"])
+    def clone_model(model_id):
+        out = {}
+        data = request.get_json(force=True)
+        timeout = -1
+        if "timeout" in data:
+            timeout = data["timeout"]
+            del data["timeout"]
+
+        opt = data.get("opt", None)
+        try:
+            model_id, load_time = translation_server.clone_model(model_id, opt, timeout)
+        except ServerModelError as e:
+            out["status"] = STATUS_ERROR
+            out["error"] = str(e)
+        else:
+            out["status"] = STATUS_OK
+            out["model_id"] = model_id
+            out["load_time"] = load_time
+
+        return jsonify(out)
+
+    @app.route("/unload_model/<int:model_id>", methods=["GET"])
+    def unload_model(model_id):
+        out = {"model_id": model_id}
+
+        try:
+            translation_server.unload_model(model_id)
+            out["status"] = STATUS_OK
+        except Exception as e:
+            out["status"] = STATUS_ERROR
+            out["error"] = str(e)
+
+        return jsonify(out)
+
+    @app.route("/translate", methods=["POST"])
+    def translate():
+        inputs = request.get_json(force=True)
+        if debug:
+            logger.info(inputs)
+        out = {}
+        try:
+            trans, scores, n_best, _, aligns, align_scores = translation_server.run(
+                inputs
+            )
+            assert len(trans) == len(inputs) * n_best
+            assert len(scores) == len(inputs) * n_best
+            assert len(aligns) == len(inputs) * n_best
+
+            out = [[] for _ in range(n_best)]
+            for i in range(len(trans)):
+                response = {
+                    "src": inputs[i // n_best]["src"],
+                    "tgt": trans[i],
+                    "n_best": n_best,
+                    "pred_score": scores[i],
+                }
+                if len(aligns[i]) > 0 and aligns[i][0] is not None:
+                    response["align"] = aligns[i]
+                    response["align_score"] = align_scores[i]
+                out[i % n_best].append(response)
+        except ServerModelError as e:
+            model_id = inputs[0].get("id")
+            if debug:
+                logger.warning(
+                    "Unload model #{} " "because of an error".format(model_id)
+                )
+            translation_server.models[model_id].unload()
+            out["error"] = str(e)
+            out["status"] = STATUS_ERROR
+        if debug:
+            logger.info(out)
+        return jsonify(out)
+
+    @app.route("/to_cpu/<int:model_id>", methods=["GET"])
+    def to_cpu(model_id):
+        out = {"model_id": model_id}
+        translation_server.models[model_id].to_cpu()
+
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    @app.route("/to_gpu/<int:model_id>", methods=["GET"])
+    def to_gpu(model_id):
+        out = {"model_id": model_id}
+        translation_server.models[model_id].to_gpu()
+
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    serve(app, host=host, port=port)
+
+
+def _get_parser():
+    parser = configargparse.ArgumentParser(
+        config_file_parser_class=configargparse.YAMLConfigFileParser,
+        description="OpenNMT-py REST Server",
+    )
+    parser.add_argument("--ip", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default="5000")
+    parser.add_argument("--url_root", type=str, default="/translator")
+    parser.add_argument("--debug", "-d", action="store_true")
+    parser.add_argument(
+        "--config", "-c", type=str, default="./available_models/conf.json"
+    )
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    args = parser.parse_args()
+    start(
+        args.config,
+        url_root=args.url_root,
+        host=args.ip,
+        port=args.port,
+        debug=args.debug,
+    )
+
+
+if __name__ == "__main__":
+    main()
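
For reference, the /translate route above takes a JSON list of segments and returns one list of hypotheses per n_best rank. A minimal client sketch using requests, assuming the server was started with the defaults above (port 5000, url_root /translator) and that a model with id 100 is declared in conf.json; the source string is a placeholder:

import requests

# Each item carries a "src" segment and the "id" of a model from the server config (assumed).
payload = [{"src": "C C O", "id": 100}]

r = requests.post("http://localhost:5000/translator/translate", json=payload)
for hypotheses in r.json():      # one list per n_best rank
    for hyp in hypotheses:
        print(hyp["pred_score"], hyp["tgt"])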
onmt/bin/train.py ADDED
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+"""Train models with dynamic data."""
+import torch
+from functools import partial
+from onmt.utils.distributed import ErrorHandler, spawned_train
+from onmt.utils.misc import set_random_seed
+from onmt.utils.logging import init_logger, logger
+from onmt.utils.parse import ArgumentParser
+from onmt.opts import train_opts
+from onmt.train_single import main as single_main
+
+
+# Set sharing strategy manually instead of default based on the OS.
+# torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def train(opt):
+    init_logger(opt.log_file)
+
+    ArgumentParser.validate_train_opts(opt)
+    ArgumentParser.update_model_opts(opt)
+    ArgumentParser.validate_model_opts(opt)
+
+    set_random_seed(opt.seed, False)
+
+    train_process = partial(single_main)
+
+    nb_gpu = len(opt.gpu_ranks)
+
+    if opt.world_size > 1:
+        mp = torch.multiprocessing.get_context("spawn")
+        # Create a thread to listen for errors in the child processes.
+        error_queue = mp.SimpleQueue()
+        error_handler = ErrorHandler(error_queue)
+        # Train with multiprocessing.
+        procs = []
+        for device_id in range(nb_gpu):
+            procs.append(
+                mp.Process(
+                    target=spawned_train,
+                    args=(train_process, opt, device_id, error_queue),
+                    daemon=False,
+                )
+            )
+            procs[device_id].start()
+            logger.info(" Starting process pid: %d " % procs[device_id].pid)
+            error_handler.add_child(procs[device_id].pid)
+        for p in procs:
+            p.join()
+
+    elif nb_gpu == 1:  # case 1 GPU only
+        train_process(opt, device_id=0)
+    else:  # case only CPU
+        train_process(opt, device_id=-1)
+
+
+def _get_parser():
+    parser = ArgumentParser(description="train.py")
+    train_opts(parser)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+
+    opt, unknown = parser.parse_known_args()
+    train(opt)
+
+
+if __name__ == "__main__":
+    main()
onmt/bin/translate.py ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from onmt.utils.logging import init_logger
+from onmt.translate.translator import build_translator
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
+from onmt.inputters.inputter import IterOnDevice
+from onmt.transforms import get_transforms_cls
+from onmt.constants import CorpusTask
+import onmt.opts as opts
+from onmt.utils.parse import ArgumentParser
+from onmt.utils.misc import use_gpu, set_random_seed
+
+
+def translate(opt):
+    ArgumentParser.validate_translate_opts(opt)
+    ArgumentParser._get_all_transform_translate(opt)
+    ArgumentParser._validate_transforms_opts(opt)
+    ArgumentParser.validate_translate_opts_dynamic(opt)
+    logger = init_logger(opt.log_file)
+
+    set_random_seed(opt.seed, use_gpu(opt))
+
+    translator = build_translator(opt, logger=logger, report_score=True)
+
+    transforms_cls = get_transforms_cls(opt._all_transform)
+
+    infer_iter = build_dynamic_dataset_iter(
+        opt,
+        transforms_cls,
+        translator.vocabs,
+        task=CorpusTask.INFER,
+        copy=translator.copy_attn,
+    )
+
+    infer_iter = IterOnDevice(infer_iter, opt.gpu)
+
+    _, _ = translator._translate(
+        infer_iter,
+        transform=infer_iter.transform,
+        attn_debug=opt.attn_debug,
+        align_debug=opt.align_debug,
+    )
+
+
+def _get_parser():
+    parser = ArgumentParser(description="translate.py")
+
+    opts.config_opts(parser)
+    opts.translate_opts(parser, dynamic=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    opt = parser.parse_args()
+    translate(opt)
+
+
+if __name__ == "__main__":
+    main()
onmt/constants.py ADDED
@@ -0,0 +1,41 @@
+"""Define constant values used across the project."""
+
+
+class DefaultTokens(object):
+    PAD = "<blank>"
+    BOS = "<s>"
+    EOS = "</s>"
+    UNK = "<unk>"
+    MASK = "<mask>"
+    VOCAB_PAD = "averyunlikelytoken"
+    SENT_FULL_STOPS = [".", "?", "!"]
+    PHRASE_TABLE_SEPARATOR = "|||"
+    ALIGNMENT_SEPARATOR = " ||| "
+    SEP = "⦅newline⦆"
+    MASK_BEFORE = "⦅_mask_before_⦆"
+
+
+class CorpusName(object):
+    VALID = "valid"
+    TRAIN = "train"
+    SAMPLE = "sample"
+    INFER = "infer"
+
+
+class CorpusTask(object):
+    TRAIN = "train"
+    VALID = "valid"
+    INFER = "infer"
+
+
+class SubwordMarker(object):
+    SPACER = "▁"
+    JOINER = "■"
+    BEGIN_UPPERCASE = "⦅mrk_begin_case_region_U⦆"
+    END_UPPERCASE = "⦅mrk_end_case_region_U⦆"
+    BEGIN_CASED = "⦅mrk_case_modifier_C⦆"
+
+
+class ModelTask(object):
+    LANGUAGE_MODEL = "lm"
+    SEQ2SEQ = "seq2seq"
onmt/decoders/__init__.py ADDED
@@ -0,0 +1,63 @@
+"""Module defining decoders."""
+import os
+import importlib
+from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
+from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
+from onmt.decoders.cnn_decoder import CNNDecoder
+
+
+str2dec = {
+    "rnn": StdRNNDecoder,
+    "ifrnn": InputFeedRNNDecoder,
+    "cnn": CNNDecoder,
+    "transformer": TransformerDecoder,
+    "transformer_lm": TransformerLMDecoder,
+}
+
+__all__ = [
+    "DecoderBase",
+    "TransformerDecoder",
+    "StdRNNDecoder",
+    "CNNDecoder",
+    "InputFeedRNNDecoder",
+    "str2dec",
+    "TransformerLMDecoder",
+]
+
+
+def get_decoders_cls(decoders_names):
+    """Return valid encoder class indicated in `decoders_names`."""
+    decoders_cls = {}
+    for name in decoders_names:
+        if name not in str2dec:
+            raise ValueError("%s decoder not supported!" % name)
+        decoders_cls[name] = str2dec[name]
+    return decoders_cls
+
+
+def register_decoder(name):
+    """Encoder register that can be used to add new encoder class."""
+
+    def register_decoder_cls(cls):
+        if name in str2dec:
+            raise ValueError("Cannot register duplicate decoder ({})".format(name))
+        if not issubclass(cls, DecoderBase):
+            raise ValueError(f"decoder ({name}: {cls.__name_}) must extend DecoderBase")
+        str2dec[name] = cls
+        __all__.append(cls.__name__)  # added to be complete
+        return cls
+
+    return register_decoder_cls
+
+
+# Auto import python files in this directory
+decoder_dir = os.path.dirname(__file__)
+for file in os.listdir(decoder_dir):
+    path = os.path.join(decoder_dir, file)
+    if (
+        not file.startswith("_")
+        and not file.startswith(".")
+        and (file.endswith(".py") or os.path.isdir(path))
+    ):
+        file_name = file[: file.find(".py")] if file.endswith(".py") else file
+        module = importlib.import_module("onmt.decoders." + file_name)
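
register_decoder() above lets external code plug a new decoder class into str2dec without editing this file. A hypothetical sketch (MyDecoder and its behaviour are placeholders, not part of this commit):

from onmt.decoders import register_decoder
from onmt.decoders.decoder import DecoderBase

@register_decoder(name="my_decoder")
class MyDecoder(DecoderBase):
    """Placeholder decoder; a real one must implement init_state/forward."""

    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls()

# After registration, "my_decoder" resolves through onmt.decoders.str2dec
# just like the built-in "transformer" or "rnn" entries.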
onmt/decoders/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3 kB). View file
 
onmt/decoders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.84 kB). View file
 
onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc ADDED
Binary file (7.32 kB). View file
 
onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc ADDED
Binary file (4.02 kB). View file
 
onmt/decoders/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (18.4 kB). View file
 
onmt/decoders/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (11.5 kB). View file
 
onmt/decoders/__pycache__/ensemble.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
onmt/decoders/__pycache__/ensemble.cpython-38.pyc ADDED
Binary file (7.13 kB). View file
 
onmt/decoders/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (32.9 kB). View file
 
onmt/decoders/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (20.4 kB). View file
 
onmt/decoders/cnn_decoder.py ADDED
@@ -0,0 +1,141 @@
+"""Implementation of the CNN Decoder part of
+"Convolutional Sequence to Sequence Learning"
+"""
+import torch
+import torch.nn as nn
+
+from onmt.modules import ConvMultiStepAttention, GlobalAttention
+from onmt.utils.cnn_factory import shape_transform, GatedConv
+from onmt.decoders.decoder import DecoderBase
+
+SCALE_WEIGHT = 0.5**0.5
+
+
+class CNNDecoder(DecoderBase):
+    """Decoder based on "Convolutional Sequence to Sequence Learning"
+    :cite:`DBLP:journals/corr/GehringAGYD17`.
+
+    Consists of residual convolutional layers, with ConvMultiStepAttention.
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        hidden_size,
+        attn_type,
+        copy_attn,
+        cnn_kernel_width,
+        dropout,
+        embeddings,
+        copy_attn_type,
+    ):
+        super(CNNDecoder, self).__init__()
+
+        self.cnn_kernel_width = cnn_kernel_width
+        self.embeddings = embeddings
+
+        # Decoder State
+        self.state = {}
+
+        input_size = self.embeddings.embedding_size
+        self.linear = nn.Linear(input_size, hidden_size)
+        self.conv_layers = nn.ModuleList(
+            [
+                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
+                for i in range(num_layers)
+            ]
+        )
+        self.attn_layers = nn.ModuleList(
+            [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
+        )
+
+        # CNNDecoder has its own attention mechanism.
+        # Set up a separate copy attention layer if needed.
+        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
+        if copy_attn:
+            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
+        else:
+            self.copy_attn = None
+
+    @classmethod
+    def from_opt(cls, opt, embeddings):
+        """Alternate constructor."""
+        return cls(
+            opt.dec_layers,
+            opt.dec_hid_size,
+            opt.global_attention,
+            opt.copy_attn,
+            opt.cnn_kernel_width,
+            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
+            embeddings,
+            opt.copy_attn_type,
+        )
+
+    def init_state(self, _, enc_out, enc_hidden):
+        """Init decoder state."""
+        self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
+        self.state["previous_input"] = None
+
+    def map_state(self, fn):
+        self.state["src"] = fn(self.state["src"], 0)
+        if self.state["previous_input"] is not None:
+            self.state["previous_input"] = fn(self.state["previous_input"], 0)
+
+    def detach_state(self):
+        self.state["previous_input"] = self.state["previous_input"].detach()
+
+    def forward(self, tgt, enc_out, step=None, **kwargs):
+        """See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
+
+        if self.state["previous_input"] is not None:
+            tgt = torch.cat([self.state["previous_input"], tgt], 1)
+
+        dec_outs = []
+        attns = {"std": []}
+        if self.copy_attn is not None:
+            attns["copy"] = []
+
+        emb = self.embeddings(tgt)
+        assert emb.dim() == 3  # batch x len x embedding_dim
+
+        tgt_emb = emb
+        # The output of CNNEncoder.
+        enc_out_t = enc_out
+        # The combination of output of CNNEncoder and source embeddings.
+        enc_out_c = self.state["src"]
+
+        emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
+        linear_out = self.linear(emb_reshape)
+        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
+        x = shape_transform(x)
+
+        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
+
+        pad = pad.type_as(x)
+        base_target_emb = x
+
+        for conv, attention in zip(self.conv_layers, self.attn_layers):
+            new_target_input = torch.cat([pad, x], 2)
+            out = conv(new_target_input)
+            c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
+            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
+
+        dec_outs = x.squeeze(3).transpose(1, 2)
+
+        # Process the result and update the attentions.
+        if self.state["previous_input"] is not None:
+            dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
+            attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
+            attn = torch.stack([attn])
+        attns["std"] = attn
+        if self.copy_attn is not None:
+            attns["copy"] = attn
+
+        # Update the state.
+        self.state["previous_input"] = tgt
+        # TODO change the way attns is returned dict => list or tuple (onnx)
+        return dec_outs, attns
+
+    def update_dropout(self, dropout, attention_dropout=None):
+        for layer in self.conv_layers:
+            layer.dropout.p = dropout
onmt/decoders/decoder.py ADDED
@@ -0,0 +1,405 @@
+import torch
+import torch.nn as nn
+
+from onmt.modules.stacked_rnn import StackedLSTM, StackedGRU
+from onmt.modules import context_gate_factory, GlobalAttention
+from onmt.utils.rnn_factory import rnn_factory
+
+
+class DecoderBase(nn.Module):
+    """Abstract class for decoders.
+
+    Args:
+        attentional (bool): The decoder returns non-empty attention.
+    """
+
+    def __init__(self, attentional=True):
+        super(DecoderBase, self).__init__()
+        self.attentional = attentional
+
+    @classmethod
+    def from_opt(cls, opt, embeddings):
+        """Alternate constructor.
+
+        Subclasses should override this method.
+        """
+
+        raise NotImplementedError
+
+
+class RNNDecoderBase(DecoderBase):
+    """Base recurrent attention-based decoder class.
+
+    Specifies the interface used by different decoder types
+    and required by :class:`~onmt.models.NMTModel`.
+
+    Args:
+        rnn_type (str):
+            style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
+        bidirectional_encoder (bool) : use with a bidirectional encoder
+        num_layers (int) : number of stacked layers
+        hidden_size (int) : hidden size of each layer
+        attn_type (str) : see :class:`~onmt.modules.GlobalAttention`
+        attn_func (str) : see :class:`~onmt.modules.GlobalAttention`
+        coverage_attn (str): see :class:`~onmt.modules.GlobalAttention`
+        context_gate (str): see :class:`~onmt.modules.ContextGate`
+        copy_attn (bool): setup a separate copy attention mechanism
+        dropout (float) : dropout value for :class:`torch.nn.Dropout`
+        embeddings (onmt.modules.Embeddings): embedding module to use
+        reuse_copy_attn (bool): reuse the attention for copying
+        copy_attn_type (str): The copy attention style. See
+            :class:`~onmt.modules.GlobalAttention`.
+    """
+
+    def __init__(
+        self,
+        rnn_type,
+        bidirectional_encoder,
+        num_layers,
+        hidden_size,
+        attn_type="general",
+        attn_func="softmax",
+        coverage_attn=False,
+        context_gate=None,
+        copy_attn=False,
+        dropout=0.0,
+        embeddings=None,
+        reuse_copy_attn=False,
+        copy_attn_type="general",
+    ):
+        super(RNNDecoderBase, self).__init__(
+            attentional=attn_type != "none" and attn_type is not None
+        )
+
+        self.bidirectional_encoder = bidirectional_encoder
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.embeddings = embeddings
+        self.dropout = nn.Dropout(dropout)
+
+        # Decoder state
+        self.state = {}
+
+        # Build the RNN.
+        self.rnn = self._build_rnn(
+            rnn_type,
+            input_size=self._input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            dropout=dropout,
+        )
+
+        # Set up the context gate.
+        self.context_gate = None
+        if context_gate is not None:
+            self.context_gate = context_gate_factory(
+                context_gate, self._input_size, hidden_size, hidden_size, hidden_size
+            )
+
+        # Set up the standard attention.
+        self._coverage = coverage_attn
+        if not self.attentional:
+            if self._coverage:
+                raise ValueError("Cannot use coverage term with no attention.")
+            self.attn = None
+        else:
+            self.attn = GlobalAttention(
+                hidden_size,
+                coverage=coverage_attn,
+                attn_type=attn_type,
+                attn_func=attn_func,
+            )
+
+        if copy_attn and not reuse_copy_attn:
+            if copy_attn_type == "none" or copy_attn_type is None:
+                raise ValueError("Cannot use copy_attn with copy_attn_type none")
+            self.copy_attn = GlobalAttention(
+                hidden_size, attn_type=copy_attn_type, attn_func=attn_func
+            )
+        else:
+            self.copy_attn = None
+
+        self._reuse_copy_attn = reuse_copy_attn and copy_attn
+        if self._reuse_copy_attn and not self.attentional:
+            raise ValueError("Cannot reuse copy attention with no attention.")
+
+    @classmethod
+    def from_opt(cls, opt, embeddings):
+        """Alternate constructor."""
+        return cls(
+            opt.rnn_type,
+            opt.brnn,
+            opt.dec_layers,
+            opt.dec_hid_size,
+            opt.global_attention,
+            opt.global_attention_function,
+            opt.coverage_attn,
+            opt.context_gate,
+            opt.copy_attn,
+            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
+            embeddings,
+            opt.reuse_copy_attn,
+            opt.copy_attn_type,
+        )
+
+    def init_state(self, src, _, enc_final_hs):
+        """Initialize decoder state with last state of the encoder."""
+
+        def _fix_enc_hidden(hidden):
+            # The encoder hidden is (layers*directions) x batch x dim.
+            # We need to convert it to layers x batch x (directions*dim).
+            if self.bidirectional_encoder:
+                hidden = torch.cat(
+                    [hidden[0 : hidden.size(0) : 2], hidden[1 : hidden.size(0) : 2]], 2
+                )
+            return hidden
+
+        if isinstance(enc_final_hs, tuple):  # LSTM
+            self.state["hidden"] = tuple(
+                _fix_enc_hidden(enc_hid) for enc_hid in enc_final_hs
+            )
+        else:  # GRU
+            self.state["hidden"] = (_fix_enc_hidden(enc_final_hs),)
+
+        # Init the input feed.
+        batch_size = self.state["hidden"][0].size(1)
+
+        h_size = (batch_size, self.hidden_size)
+        self.state["input_feed"] = (
+            self.state["hidden"][0].data.new(*h_size).zero_().unsqueeze(0)
+        )
+
+        self.state["coverage"] = None
+
+    def map_state(self, fn):
+        self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
+        self.state["input_feed"] = fn(self.state["input_feed"], 1)
+        if self._coverage and self.state["coverage"] is not None:
+            self.state["coverage"] = fn(self.state["coverage"], 1)
+
+    def detach_state(self):
+        self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
+        self.state["input_feed"] = self.state["input_feed"].detach()
+        if self._coverage and self.state["coverage"] is not None:
+            self.state["coverage"] = self.state["coverage"].detach()
+
+    def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
+        """
+        Args:
+            tgt (LongTensor): sequences of padded tokens
+                ``(batch, tgt_len, nfeats)``.
+            enc_out (FloatTensor): vectors from the encoder
+                ``(batch, src_len, hidden)``.
+            src_len (LongTensor): the padded source lengths
+                ``(batch,)``.
+
+        Returns:
+            (FloatTensor, dict[str, FloatTensor]):
+
+            * dec_outs: output from the decoder (after attn)
+              ``(batch, tgt_len, hidden)``.
+            * attns: distribution over src at each tgt
+              ``(batch, tgt_len, src_len)``.
+        """
+        dec_state, dec_outs, attns = self._run_forward_pass(
+            tgt, enc_out, src_len=src_len
+        )
+
+        # Update the state with the result.
+        if not isinstance(dec_state, tuple):
+            dec_state = (dec_state,)
+        self.state["hidden"] = dec_state
+
+        # Concatenates sequence of tensors along a new dimension.
+        # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list
+        #       (in particular in case of SRU) it was not raising error in 0.3
+        #       since stack(Variable) was allowed.
+        #       In 0.4, SRU returns a tensor that shouldn't be stacke
+        if type(dec_outs) == list:
+            dec_outs = torch.stack(dec_outs, dim=1)
+            for k in attns:
+                if type(attns[k]) == list:
+                    attns[k] = torch.stack(attns[k])
+
+        self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0)
+        self.state["coverage"] = None
+        if "coverage" in attns:
+            self.state["coverage"] = attns["coverage"][-1, :, :].unsqueeze(0)
+
+        return dec_outs, attns
+
+    def update_dropout(self, dropout, attention_dropout=None):
+        self.dropout.p = dropout
+        self.embeddings.update_dropout(dropout)
+
+
+class StdRNNDecoder(RNNDecoderBase):
+    """Standard fully batched RNN decoder with attention.
+
+    Faster implementation, uses CuDNN for implementation.
+    See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
+
+    Based around the approach from
+    "Neural Machine Translation By Jointly Learning To Align and Translate"
+    :cite:`Bahdanau2015`
+
+    Implemented without input_feeding and currently with no `coverage_attn`
+    or `copy_attn` support.
+    """
+
+    def _run_forward_pass(self, tgt, enc_out, src_len=None):
+        """
+        Private helper for running the specific RNN forward pass.
+        Must be overriden by all subclasses.
+
+        Args:
+            tgt (LongTensor): a sequence of input tokens tensors
+                ``(batch, tgt_len, nfeats)``.
+            enc_out (FloatTensor): output(tensor sequence) from the
+                encoder RNN of size ``(batch, src_len, hidden_size)``.
+            src_len (LongTensor): the source enc_out lengths.
+
+        Returns:
+            (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]):
+
+            * dec_state: final hidden state from the decoder.
+            * dec_outs: an array of output of every time
+              step from the decoder.
+            * attns: a dictionary of different
+              type of attention Tensor array of every time
+              step from the decoder.
+        """
+
+        assert self.copy_attn is None  # TODO, no support yet.
+        assert not self._coverage  # TODO, no support yet.
+
+        attns = {}
+        emb = self.embeddings(tgt)
+
+        if isinstance(self.rnn, nn.GRU):
+            rnn_out, dec_state = self.rnn(emb, self.state["hidden"][0])
+        else:
+            rnn_out, dec_state = self.rnn(emb, self.state["hidden"])
+
+        tgt_batch, tgt_len, _ = tgt.size()
+
+        # Calculate the attention.
+        if not self.attentional:
+            dec_outs = rnn_out
+        else:
+            dec_outs, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
+            attns["std"] = p_attn
+
+        # Calculate the context gate.
+        if self.context_gate is not None:
+            dec_outs = self.context_gate(
+                emb.view(-1, emb.size(2)),
+                rnn_out.view(-1, rnn_out.size(2)),
+                dec_outs.view(-1, dec_outs.size(2)),
+            )
+            dec_outs = dec_outs.view(tgt_batch, tgt_len, self.hidden_size)
+
+        dec_outs = self.dropout(dec_outs)
+
+        return dec_state, dec_outs, attns
+
+    def _build_rnn(self, rnn_type, **kwargs):
+        rnn, _ = rnn_factory(rnn_type, **kwargs)
+        return rnn
+
+    @property
+    def _input_size(self):
+        return self.embeddings.embedding_size
+
+
+class InputFeedRNNDecoder(RNNDecoderBase):
+    """Input feeding based decoder.
+
+    See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
+
+    Based around the input feeding approach from
+    "Effective Approaches to Attention-based Neural Machine Translation"
+    :cite:`Luong2015`
+    """
+
+    def _run_forward_pass(self, tgt, enc_out, src_len=None):
+        """
+        See StdRNNDecoder._run_forward_pass() for description
+        of arguments and return values.
+        """
+        # Additional args check.
+        input_feed = self.state["input_feed"].squeeze(0)
+
+        dec_outs = []
+        attns = {}
+        if self.attn is not None:
+            attns["std"] = []
+        if self.copy_attn is not None or self._reuse_copy_attn:
+            attns["copy"] = []
+        if self._coverage:
+            attns["coverage"] = []
+
+        emb = self.embeddings(tgt)
+        assert emb.dim() == 3  # batch x len x embedding_dim
+
+        dec_state = self.state["hidden"]
+
+        coverage = (
+            self.state["coverage"].squeeze(0)
+            if self.state["coverage"] is not None
+            else None
+        )
+
+        # Input feed concatenates hidden state with
+        # input at every time step.
+        for emb_t in emb.split(1, dim=1):
+            dec_in = torch.cat([emb_t.squeeze(1), input_feed], 1)
+            rnn_out, dec_state = self.rnn(dec_in, dec_state)
+            if self.attentional:
+                dec_out, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
+                attns["std"].append(p_attn)
+            else:
+                dec_out = rnn_out
+            if self.context_gate is not None:
+                # TODO: context gate should be employed
+                # instead of second RNN transform.
+                dec_out = self.context_gate(dec_in, rnn_out, dec_out)
+            dec_out = self.dropout(dec_out)
+            input_feed = dec_out
+
+            dec_outs += [dec_out]
+
+            # Update the coverage attention.
+            # attns["coverage"] is actually c^(t+1) of See et al(2017)
+            # 1-index shifted
+            if self._coverage:
+                coverage = p_attn if coverage is None else p_attn + coverage
+                attns["coverage"] += [coverage]
+
+            if self.copy_attn is not None:
+                _, copy_attn = self.copy_attn(dec_out, enc_out)
+                attns["copy"] += [copy_attn]
+            elif self._reuse_copy_attn:
+                attns["copy"] = attns["std"]
+
+        return dec_state, dec_outs, attns
+
+    def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout):
+        assert rnn_type != "SRU", (
+            "SRU doesn't support input feed! " "Please set -input_feed 0!"
+        )
+        stacked_cell = StackedLSTM if rnn_type == "LSTM" else StackedGRU
+        return stacked_cell(num_layers, input_size, hidden_size, dropout)
+
+    @property
+    def _input_size(self):
+        """Using input feed by concatenating input with attention vectors."""
+        return self.embeddings.embedding_size + self.hidden_size
+
+    def update_dropout(self, dropout, attention_dropout=None):
+        self.dropout.p = dropout
+        self.rnn.dropout.p = dropout
+        self.embeddings.update_dropout(dropout)
onmt/decoders/ensemble.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ensemble decoding.
2
+
3
+ Decodes using multiple models simultaneously,
4
+ combining their prediction distributions by averaging.
5
+ All models in the ensemble must share a target vocabulary.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from onmt.encoders.encoder import EncoderBase
12
+ from onmt.decoders.decoder import DecoderBase
13
+ from onmt.models import NMTModel
14
+ import onmt.model_builder
15
+
16
+
17
+ class EnsembleDecoderOutput(object):
18
+ """Wrapper around multiple decoder final hidden states."""
19
+
20
+ def __init__(self, model_dec_outs):
21
+ self.model_dec_outs = tuple(model_dec_outs)
22
+
23
+ def squeeze(self, dim=None):
24
+ """Delegate squeeze to avoid modifying
25
+ :func:`onmt.translate.translator.Translator.translate_batch()`
26
+ """
27
+ return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])
28
+
29
+ def __getitem__(self, index):
30
+ return self.model_dec_outs[index]
31
+
32
+
33
+ class EnsembleEncoder(EncoderBase):
34
+ """Dummy Encoder that delegates to individual real Encoders."""
35
+
36
+ def __init__(self, model_encoders):
37
+ super(EnsembleEncoder, self).__init__()
38
+ self.model_encoders = nn.ModuleList(model_encoders)
39
+
40
+ def forward(self, src, src_len=None):
41
+ enc_out, enc_final_hs, _ = zip(
42
+ *[model_encoder(src, src_len) for model_encoder in self.model_encoders]
43
+ )
44
+ return enc_out, enc_final_hs, src_len
45
+
46
+
47
+ class EnsembleDecoder(DecoderBase):
48
+ """Dummy Decoder that delegates to individual real Decoders."""
49
+
50
+ def __init__(self, model_decoders):
51
+ model_decoders = nn.ModuleList(model_decoders)
52
+ attentional = any([dec.attentional for dec in model_decoders])
53
+ super(EnsembleDecoder, self).__init__(attentional)
54
+ self.model_decoders = model_decoders
55
+
56
+ def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
57
+ """See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
58
+ # src_len is a single tensor shared between all models.
59
+ # This assumption will not hold if Translator is modified
60
+ # to calculate src_len as something other than the length
61
+ # of the input.
62
+ dec_outs, attns = zip(
63
+ *[
64
+ model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
65
+ for i, model_decoder in enumerate(self.model_decoders)
66
+ ]
67
+ )
68
+ mean_attns = self.combine_attns(attns)
69
+ return EnsembleDecoderOutput(dec_outs), mean_attns
70
+
71
+ def combine_attns(self, attns):
72
+ result = {}
73
+ for key in attns[0].keys():
74
+ result[key] = torch.stack(
75
+ [attn[key] for attn in attns if attn[key] is not None]
76
+ ).mean(0)
77
+ return result
78
+
79
+ def init_state(self, src, enc_out, enc_hidden):
80
+ """See :obj:`RNNDecoderBase.init_state()`"""
81
+ for i, model_decoder in enumerate(self.model_decoders):
82
+ model_decoder.init_state(src, enc_out[i], enc_hidden[i])
83
+
84
+ def map_state(self, fn):
85
+ for model_decoder in self.model_decoders:
86
+ model_decoder.map_state(fn)
87
+
88
+
89
+ class EnsembleGenerator(nn.Module):
90
+ """
91
+ Dummy Generator that delegates to individual real Generators,
92
+ and then averages the resulting target distributions.
93
+ """
94
+
95
+ def __init__(self, model_generators, raw_probs=False):
96
+ super(EnsembleGenerator, self).__init__()
97
+ self.model_generators = nn.ModuleList(model_generators)
98
+ self._raw_probs = raw_probs
99
+
100
+ def forward(self, hidden, attn=None, src_map=None):
101
+ """
102
+ Compute a distribution over the target dictionary
103
+ by averaging distributions from models in the ensemble.
104
+ All models in the ensemble must share a target vocabulary.
105
+ """
106
+ distributions = torch.stack(
107
+ [
108
+ mg(h) if attn is None else mg(h, attn, src_map)
109
+ for h, mg in zip(hidden, self.model_generators)
110
+ ]
111
+ )
112
+ if self._raw_probs:
113
+ return torch.log(torch.exp(distributions).mean(0))
114
+ else:
115
+ return distributions.mean(0)
116
+
117
+
118
+ class EnsembleModel(NMTModel):
119
+ """Dummy NMTModel wrapping individual real NMTModels."""
120
+
121
+ def __init__(self, models, raw_probs=False):
122
+ encoder = EnsembleEncoder(model.encoder for model in models)
123
+ decoder = EnsembleDecoder(model.decoder for model in models)
124
+ super(EnsembleModel, self).__init__(encoder, decoder)
125
+ self.generator = EnsembleGenerator(
126
+ [model.generator for model in models], raw_probs
127
+ )
128
+ self.models = nn.ModuleList(models)
129
+
130
+
131
+ def load_test_model(opt, device_id=0):
132
+ """Read in multiple models for ensemble."""
133
+ shared_vocabs = None
134
+ shared_model_opt = None
135
+ models = []
136
+ for model_path in opt.models:
137
+ vocabs, model, model_opt = onmt.model_builder.load_test_model(
138
+ opt, device_id, model_path=model_path
139
+ )
140
+ if shared_vocabs is None:
141
+ shared_vocabs = vocabs
142
+ else:
143
+ assert (
144
+ shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
145
+ ), "Ensemble models must use the same vocabs "
146
+ models.append(model)
147
+ if shared_model_opt is None:
148
+ shared_model_opt = model_opt
149
+ ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
150
+ return shared_vocabs, ensemble_model, shared_model_opt
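For reference, the averaging performed by EnsembleGenerator above can be reproduced with a minimal standalone sketch. The two nn.Sequential "generators" below are toy stand-ins (not part of this repository) that simply map a hidden state to log-probabilities over a small vocabulary:

import torch
import torch.nn as nn

vocab_size, d_model = 10, 8
# toy stand-ins for the per-model generators (illustrative only)
generators = [
    nn.Sequential(nn.Linear(d_model, vocab_size), nn.LogSoftmax(dim=-1))
    for _ in range(2)
]
# one decoder output per model, as carried by EnsembleDecoderOutput
hidden = [torch.randn(3, d_model) for _ in generators]
distributions = torch.stack(
    [gen(h) for h, gen in zip(hidden, generators)]
)  # (n_models, 3, vocab_size)

avg_raw_probs = True
if avg_raw_probs:
    # average in probability space, then return to log space (raw_probs branch)
    combined = torch.log(torch.exp(distributions).mean(0))
else:
    # average the log-probabilities directly
    combined = distributions.mean(0)
print(combined.shape)  # torch.Size([3, 10])

Which branch is taken is controlled by opt.avg_raw_probs in load_test_model above.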
onmt/decoders/transformer.py ADDED
@@ -0,0 +1,835 @@
1
+ """
2
+ Implementation of "Attention is All You Need" and of
3
+ subsequent transformer based architectures
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from onmt.decoders.decoder import DecoderBase
9
+ from onmt.modules import MultiHeadedAttention, AverageAttention
10
+ from onmt.modules.position_ffn import PositionwiseFeedForward
11
+ from onmt.modules.position_ffn import ActivationFunction
12
+ from onmt.utils.misc import sequence_mask
13
+ from onmt.modules.rmsnorm import RMSNorm
14
+
15
+
16
+ class TransformerDecoderLayerBase(nn.Module):
17
+ def __init__(
18
+ self,
19
+ d_model,
20
+ heads,
21
+ d_ff,
22
+ dropout,
23
+ attention_dropout,
24
+ self_attn_type="scaled-dot",
25
+ max_relative_positions=0,
26
+ relative_positions_buckets=0,
27
+ aan_useffn=False,
28
+ full_context_alignment=False,
29
+ alignment_heads=0,
30
+ pos_ffn_activation_fn=ActivationFunction.relu,
31
+ add_qkvbias=False,
32
+ num_kv=0,
33
+ add_ffnbias=True,
34
+ parallel_residual=False,
35
+ shared_layer_norm=False,
36
+ layer_norm="standard",
37
+ norm_eps=1e-6,
38
+ use_ckpting=[],
39
+ parallel_gpu=1,
40
+ ):
41
+ """
42
+ Args:
43
+ d_model (int): the dimension of keys/values/queries in
44
+ :class:`MultiHeadedAttention`, also the input size of
45
+ the first-layer of the :class:`PositionwiseFeedForward`.
46
+ heads (int): the number of heads for MultiHeadedAttention.
47
+ d_ff (int): the second-layer of the
48
+ :class:`PositionwiseFeedForward`.
49
+ dropout (float): dropout in residual, self-attn(dot) and
50
+ feed-forward
51
+ attention_dropout (float): dropout in context_attn (and
52
+ self-attn(avg))
53
+ self_attn_type (string): type of self-attention scaled-dot,
54
+ average
55
+ max_relative_positions (int):
56
+ Max distance between inputs in relative positions
57
+ representations
58
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
59
+ full_context_alignment (bool):
60
+ whether enable an extra full context decoder forward for
61
+ alignment
62
+ alignment_heads (int):
63
+ N. of cross attention heads to use for alignment guiding
64
+ pos_ffn_activation_fn (ActivationFunction):
65
+ activation function choice for PositionwiseFeedForward layer
66
+ add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
67
+ layer_norm (string): type of layer normalization standard/rms
68
+ norm_eps (float): layer norm epsilon
69
+
70
+ """
71
+ super(TransformerDecoderLayerBase, self).__init__()
72
+
73
+ self.self_attn_type = self_attn_type
74
+ if self_attn_type == "scaled-dot":
75
+ self.self_attn = MultiHeadedAttention(
76
+ heads,
77
+ d_model,
78
+ dropout=attention_dropout,
79
+ max_relative_positions=max_relative_positions,
80
+ relative_positions_buckets=relative_positions_buckets,
81
+ attn_type="self",
82
+ add_qkvbias=add_qkvbias,
83
+ num_kv=num_kv,
84
+ use_ckpting=use_ckpting,
85
+ parallel_gpu=parallel_gpu,
86
+ )
87
+ elif self_attn_type == "average":
88
+ self.self_attn = AverageAttention(
89
+ d_model, dropout=attention_dropout, aan_useffn=aan_useffn
90
+ )
91
+
92
+ self.feed_forward = PositionwiseFeedForward(
93
+ d_model,
94
+ d_ff,
95
+ dropout,
96
+ pos_ffn_activation_fn,
97
+ add_ffnbias,
98
+ parallel_residual,
99
+ layer_norm,
100
+ norm_eps,
101
+ use_ckpting=use_ckpting,
102
+ parallel_gpu=parallel_gpu,
103
+ )
104
+ self.parallel_residual = parallel_residual
105
+ self.shared_layer_norm = shared_layer_norm
106
+ if layer_norm == "standard":
107
+ self.layer_norm_1 = nn.LayerNorm(d_model, eps=norm_eps)
108
+ if parallel_residual and not shared_layer_norm:
109
+ self.layer_norm_res = nn.LayerNorm(d_model, eps=norm_eps)
110
+ elif layer_norm == "rms":
111
+ self.layer_norm_1 = RMSNorm(d_model, eps=norm_eps)
112
+ if parallel_residual and not shared_layer_norm:
113
+ self.layer_norm_res = RMSNorm(d_model, eps=norm_eps)
114
+ else:
115
+ raise ValueError(f"{layer_norm} layer norm type is not supported")
116
+
117
+ self.dropout = nn.Dropout(dropout)
118
+ self.full_context_alignment = full_context_alignment
119
+ self.alignment_heads = alignment_heads
120
+
121
+ def forward(self, *args, **kwargs):
122
+ """Extend `_forward` for (possibly) multiple decoder pass:
123
+ Always a default (future masked) decoder forward pass,
124
+ Possibly a second future-aware decoder pass for jointly learning
125
+ full context alignment, :cite:`garg2019jointly`.
126
+
127
+ Args:
128
+ * All arguments of _forward, of which
129
+ with_align (bool): needed to compute attn_align
130
+ return_attn (bool): to force MHA to return attns
131
+
132
+ Returns:
133
+ (FloatTensor, FloatTensor, FloatTensor or None):
134
+
135
+ * layer_out ``(batch_size, T, model_dim)``
136
+ * top_attn ``(batch_size, T, src_len)``
137
+ * attn_align ``(batch_size, T, src_len)`` or None
138
+ """
139
+ with_align = kwargs.pop("with_align", False)
140
+ layer_out, attns = self._forward(*args, **kwargs)
141
+ top_attn = None if attns is None else attns[:, 0, :, :].contiguous()
142
+ attn_align = None
143
+ if with_align:
144
+ if self.full_context_alignment:
145
+ # return _, (B, Q_len, K_len)
146
+ _, attns = self._forward(*args, **kwargs, future=True)
147
+
148
+ if self.alignment_heads > 0:
149
+ attns = attns[:, : self.alignment_heads, :, :].contiguous()
150
+ # layer average attention across heads, get ``(B, Q, K)``
151
+ # Case 1: no full_context, no align heads -> layer avg baseline
152
+ # Case 2: no full_context, 1 align heads -> guided align
153
+ # Case 3: full_context, 1 align heads -> full cte guided align
154
+ attn_align = attns.mean(dim=1)
155
+ return layer_out, top_attn, attn_align
156
+
157
+ def update_dropout(self, dropout, attention_dropout):
158
+ self.self_attn.update_dropout(attention_dropout)
159
+ self.feed_forward.update_dropout(dropout)
160
+ self.dropout.p = dropout
161
+
162
+ def _forward(self, *args, **kwargs):
163
+ raise NotImplementedError
164
+
165
+ def _compute_dec_mask(self, tgt_pad_mask, future):
166
+ tgt_len = tgt_pad_mask.size(-1)
167
+ if not future: # apply future_mask, result mask in (B, T, T)
168
+ future_mask = torch.ones(
169
+ [tgt_len, tgt_len],
170
+ device=tgt_pad_mask.device,
171
+ dtype=torch.uint8,
172
+ )
173
+ future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
174
+ # BoolTensor was introduced in pytorch 1.2
175
+ try:
176
+ future_mask = future_mask.bool()
177
+ except AttributeError:
178
+ pass
179
+ dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
180
+ else: # only mask padding, result mask in (B, 1, T)
181
+ dec_mask = tgt_pad_mask
182
+ return dec_mask
183
+
184
+ def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
185
+ if self.self_attn_type == "scaled-dot":
186
+ return self.self_attn(
187
+ norm_layer_in,
188
+ norm_layer_in,
189
+ norm_layer_in,
190
+ mask=dec_mask,
191
+ step=step,
192
+ return_attn=return_attn,
193
+ )
194
+ elif self.self_attn_type == "average":
195
+ return self.self_attn(norm_layer_in, mask=dec_mask, step=step)
196
+ else:
197
+ raise ValueError(f"self attention {type(self.self_attn)} not supported")
198
+
199
+
200
+ class TransformerDecoderLayer(TransformerDecoderLayerBase):
201
+ """Transformer Decoder layer block in Pre-Norm style.
202
+ Pre-Norm style is an improvement over the original paper's Post-Norm style,
203
+ providing better convergence speed and performance. This is also the actual
204
+ implementation in tensor2tensor and is also available in fairseq.
205
+ See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
206
+
207
+ """
208
+
209
+ def __init__(
210
+ self,
211
+ d_model,
212
+ heads,
213
+ d_ff,
214
+ dropout,
215
+ attention_dropout,
216
+ self_attn_type="scaled-dot",
217
+ max_relative_positions=0,
218
+ relative_positions_buckets=0,
219
+ aan_useffn=False,
220
+ full_context_alignment=False,
221
+ alignment_heads=0,
222
+ pos_ffn_activation_fn=ActivationFunction.relu,
223
+ add_qkvbias=False,
224
+ num_kv=0,
225
+ add_ffnbias=True,
226
+ parallel_residual=False,
227
+ shared_layer_norm=False,
228
+ layer_norm="standard",
229
+ norm_eps=1e-6,
230
+ use_ckpting=[],
231
+ parallel_gpu=1,
232
+ ):
233
+ """
234
+ Args:
235
+ See TransformerDecoderLayerBase
236
+ """
237
+ super(TransformerDecoderLayer, self).__init__(
238
+ d_model,
239
+ heads,
240
+ d_ff,
241
+ dropout,
242
+ attention_dropout,
243
+ self_attn_type,
244
+ max_relative_positions,
245
+ relative_positions_buckets,
246
+ aan_useffn,
247
+ full_context_alignment,
248
+ alignment_heads,
249
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
250
+ add_qkvbias=add_qkvbias,
251
+ num_kv=num_kv,
252
+ add_ffnbias=add_ffnbias,
253
+ parallel_residual=parallel_residual,
254
+ shared_layer_norm=shared_layer_norm,
255
+ layer_norm=layer_norm,
256
+ norm_eps=norm_eps,
257
+ use_ckpting=use_ckpting,
258
+ parallel_gpu=parallel_gpu,
259
+ )
260
+ self.context_attn = MultiHeadedAttention(
261
+ heads,
262
+ d_model,
263
+ dropout=attention_dropout,
264
+ attn_type="context",
265
+ add_qkvbias=add_qkvbias,
266
+ num_kv=num_kv,
267
+ use_ckpting=use_ckpting,
268
+ parallel_gpu=parallel_gpu,
269
+ )
270
+ if layer_norm == "standard":
271
+ self.layer_norm_2 = nn.LayerNorm(d_model, eps=norm_eps)
272
+ elif layer_norm == "rms":
273
+ self.layer_norm_2 = RMSNorm(d_model, eps=norm_eps)
274
+ else:
275
+ raise ValueError(f"{layer_norm} layer norm type is not supported")
276
+
277
+ def update_dropout(self, dropout, attention_dropout):
278
+ super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout)
279
+ self.context_attn.update_dropout(attention_dropout)
280
+
281
+ def _forward(
282
+ self,
283
+ layer_in,
284
+ enc_out,
285
+ src_pad_mask,
286
+ tgt_pad_mask,
287
+ step=None,
288
+ future=False,
289
+ return_attn=False,
290
+ ):
291
+ """A naive forward pass for transformer decoder.
292
+
293
+ # T: could be 1 in the case of stepwise decoding or tgt_len
294
+
295
+ Args:
296
+ layer_in (FloatTensor): ``(batch_size, T, model_dim)``
297
+ enc_out (FloatTensor): ``(batch_size, src_len, model_dim)``
298
+ src_pad_mask (bool): ``(batch_size, 1, src_len)``
299
+ tgt_pad_mask (bool): ``(batch_size, 1, T)``
300
+ step (int or None): stepwise decoding counter
301
+ future (bool): If set True, do not apply future_mask.
302
+ return_attn (bool): if set True, return the attention weights
303
+
304
+ Returns:
305
+ (FloatTensor, FloatTensor):
306
+
307
+ * layer_out ``(batch_size, T, model_dim)``
308
+ * attns ``(batch_size, head, T, src_len)``
309
+
310
+ """
311
+ dec_mask = None
312
+ src_pad_mask = src_pad_mask.unsqueeze(1) # [B,1,1,slen]
313
+
314
+ if layer_in.size(1) > 1:
315
+ # masking is necessary when sequence length is greater than one
316
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
317
+ dec_mask = dec_mask.unsqueeze(1)
318
+ dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
319
+ src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1)
320
+ # masks are now (batch x 1 x tlen x slen or tlen)
321
+ # 1 = heads to be expanded in MHA
322
+
323
+ norm_layer_in = self.layer_norm_1(layer_in)
324
+
325
+ self_attn, _ = self._forward_self_attn(norm_layer_in, dec_mask, step)
326
+
327
+ if self.parallel_residual:
328
+ ctx_attn, attns = self.context_attn(
329
+ enc_out,
330
+ enc_out,
331
+ norm_layer_in,
332
+ mask=src_pad_mask,
333
+ return_attn=return_attn,
334
+ )
335
+ # feed_forward adds its own residual, so remove it and re-apply the residual with the un-normed input
336
+ layer_out = (
337
+ self.feed_forward(norm_layer_in)
338
+ - norm_layer_in
339
+ + layer_in
340
+ + self.dropout(self_attn)
341
+ + ctx_attn
342
+ )
343
+ else:
344
+ query = self.dropout(self_attn) + layer_in
345
+ norm_query = self.layer_norm_2(query)
346
+ ctx_attn, attns = self.context_attn(
347
+ enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
348
+ )
349
+ layer_out = self.feed_forward(self.dropout(ctx_attn) + query)
350
+
351
+ return layer_out, attns
352
+
353
+
354
+ class TransformerDecoderBase(DecoderBase):
355
+ def __init__(
356
+ self, d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
357
+ ):
358
+ super(TransformerDecoderBase, self).__init__()
359
+
360
+ self.embeddings = embeddings
361
+
362
+ # Decoder State
363
+ self.state = {}
364
+
365
+ # previously, there was a GlobalAttention module here for copy
366
+ # attention. But it was never actually used -- the "copy" attention
367
+ # just reuses the context attention.
368
+ self._copy = copy_attn
369
+ if layer_norm == "standard":
370
+ self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
371
+ elif layer_norm == "rms":
372
+ self.layer_norm = RMSNorm(d_model, eps=norm_eps)
373
+ else:
374
+ raise ValueError(f"{layer_norm} layer norm type is not supported")
375
+
376
+ self.alignment_layer = alignment_layer
377
+
378
+ @classmethod
379
+ def from_opt(cls, opt, embeddings):
380
+ """Alternate constructor."""
381
+ return cls(
382
+ opt.dec_layers,
383
+ opt.dec_hid_size,
384
+ opt.heads,
385
+ opt.transformer_ff,
386
+ opt.copy_attn,
387
+ opt.self_attn_type,
388
+ opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
389
+ opt.attention_dropout[0]
390
+ if type(opt.attention_dropout) is list
391
+ else opt.attention_dropout,
392
+ embeddings,
393
+ opt.max_relative_positions,
394
+ opt.relative_positions_buckets,
395
+ opt.aan_useffn,
396
+ opt.full_context_alignment,
397
+ opt.alignment_layer,
398
+ alignment_heads=opt.alignment_heads,
399
+ pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
400
+ add_qkvbias=opt.add_qkvbias,
401
+ num_kv=opt.num_kv,
402
+ add_ffnbias=opt.add_ffnbias,
403
+ parallel_residual=opt.parallel_residual,
404
+ shared_layer_norm=opt.shared_layer_norm,
405
+ layer_norm=opt.layer_norm,
406
+ norm_eps=opt.norm_eps,
407
+ use_ckpting=opt.use_ckpting,
408
+ parallel_gpu=opt.world_size
409
+ if opt.parallel_mode == "tensor_parallel"
410
+ else 1,
411
+ )
412
+
413
+ def init_state(self, src, enc_out, enc_final_hs):
414
+ """Initialize decoder state."""
415
+ self.state["src"] = src
416
+
417
+ def map_state(self, fn):
418
+ if self.state["src"] is not None:
419
+ self.state["src"] = fn(self.state["src"], 0)
420
+ for layer in self.transformer_layers:
421
+ if hasattr(layer, "context_attn"):
422
+ if layer.context_attn.layer_cache[1]["keys"].numel() != 0:
423
+ x = fn(layer.context_attn.layer_cache[1]["keys"], 0)
424
+ y = fn(layer.context_attn.layer_cache[1]["values"], 0)
425
+ layer.context_attn.layer_cache = True, {"keys": x, "values": y}
426
+ if isinstance(layer.self_attn, AverageAttention):
427
+ if layer.self_attn.layer_cache[1]["prev_g"].numel() != 0:
428
+ x = fn(layer.self_attn.layer_cache[1]["prev_g"], 0)
429
+ layer.self_attn.layer_cache = True, {"prev_g": x}
430
+ else:
431
+ if layer.self_attn.layer_cache[1]["keys"].numel() != 0:
432
+ x = fn(layer.self_attn.layer_cache[1]["keys"], 0)
433
+ y = fn(layer.self_attn.layer_cache[1]["values"], 0)
434
+ layer.self_attn.layer_cache = True, {"keys": x, "values": y}
435
+
436
+ def detach_state(self):
437
+ raise NotImplementedError
438
+
439
+ def forward(self, *args, **kwargs):
440
+ raise NotImplementedError
441
+
442
+ def update_dropout(self, dropout, attention_dropout):
443
+ self.embeddings.update_dropout(dropout)
444
+ for layer in self.transformer_layers:
445
+ layer.update_dropout(dropout, attention_dropout)
446
+
447
+
448
+ class TransformerDecoder(TransformerDecoderBase):
449
+ """The Transformer decoder from "Attention is All You Need".
450
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
451
+
452
+ Args:
453
+ num_layers (int): number of decoder layers.
454
+ d_model (int): size of the model
455
+ heads (int): number of heads
456
+ d_ff (int): size of the inner FF layer
457
+ copy_attn (bool): if using a separate copy attention
458
+ self_attn_type (str): type of self-attention scaled-dot, average
459
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
460
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
461
+ embeddings (onmt.modules.Embeddings):
462
+ embeddings to use, should have positional encodings
463
+ max_relative_positions (int):
464
+ Max distance between inputs in relative positions representations
465
+ relative_positions_buckets (int):
466
+ Number of buckets when using relative position bias
467
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
468
+ full_context_alignment (bool):
469
+ whether enable an extra full context decoder forward for alignment
470
+ alignment_layer (int): N° Layer to supervise with for alignment guiding
471
+ alignment_heads (int):
472
+ N. of cross attention heads to use for alignment guiding
473
+ add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
474
+ layer_norm (string): type of layer normalization standard/rms
475
+ """
476
+
477
+ def __init__(
478
+ self,
479
+ num_layers,
480
+ d_model,
481
+ heads,
482
+ d_ff,
483
+ copy_attn,
484
+ self_attn_type,
485
+ dropout,
486
+ attention_dropout,
487
+ embeddings,
488
+ max_relative_positions,
489
+ relative_positions_buckets,
490
+ aan_useffn,
491
+ full_context_alignment,
492
+ alignment_layer,
493
+ alignment_heads,
494
+ pos_ffn_activation_fn=ActivationFunction.relu,
495
+ add_qkvbias=False,
496
+ num_kv=0,
497
+ add_ffnbias=True,
498
+ parallel_residual=False,
499
+ shared_layer_norm=False,
500
+ layer_norm="standard",
501
+ norm_eps=1e-6,
502
+ use_ckpting=[],
503
+ parallel_gpu=1,
504
+ ):
505
+ super(TransformerDecoder, self).__init__(
506
+ d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
507
+ )
508
+
509
+ self.transformer_layers = nn.ModuleList(
510
+ [
511
+ TransformerDecoderLayer(
512
+ d_model,
513
+ heads,
514
+ d_ff,
515
+ dropout,
516
+ attention_dropout,
517
+ self_attn_type=self_attn_type,
518
+ max_relative_positions=max_relative_positions,
519
+ relative_positions_buckets=relative_positions_buckets,
520
+ aan_useffn=aan_useffn,
521
+ full_context_alignment=full_context_alignment,
522
+ alignment_heads=alignment_heads,
523
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
524
+ add_qkvbias=add_qkvbias,
525
+ num_kv=num_kv,
526
+ add_ffnbias=add_ffnbias,
527
+ parallel_residual=parallel_residual,
528
+ shared_layer_norm=shared_layer_norm,
529
+ layer_norm=layer_norm,
530
+ norm_eps=norm_eps,
531
+ use_ckpting=use_ckpting,
532
+ parallel_gpu=parallel_gpu,
533
+ )
534
+ for i in range(num_layers)
535
+ ]
536
+ )
537
+
538
+ def detach_state(self):
539
+ self.state["src"] = self.state["src"].detach()
540
+
541
+ def forward(self, tgt, enc_out=None, step=None, **kwargs):
542
+ """
543
+ Decode, possibly stepwise.
544
+ During training, step is always None; during decoding, step increases with each generated token.
545
+ tgt (Tensor): batch x tlen x feats
546
+ enc_out (Tensor): encoder output (batch x slen x model_dim)
547
+ """
548
+ if enc_out is None:
549
+ enc_out = self.embeddings(tgt)
550
+ if step == 0:
551
+ self._init_cache(enc_out)
552
+ elif step is None:
553
+ for layer in self.transformer_layers:
554
+ if isinstance(layer.self_attn, AverageAttention):
555
+ layer.self_attn.layer_cache = False, {"prev_g": torch.tensor([])}
556
+ else:
557
+ layer.self_attn.layer_cache = (
558
+ False,
559
+ {"keys": torch.tensor([]), "values": torch.tensor([])},
560
+ )
561
+ layer.context_attn.layer_cache = (
562
+ False,
563
+ {"keys": torch.tensor([]), "values": torch.tensor([])},
564
+ )
565
+
566
+ emb = self.embeddings(tgt, step=step)
567
+ dec_out = emb
568
+ assert emb.dim() == 3 # batch x len x embedding_dim
569
+
570
+ pad_idx = self.embeddings.word_padding_idx
571
+ src_lens = kwargs["src_len"]
572
+ src_max_len = self.state["src"].shape[1]
573
+ src_pad_mask = ~sequence_mask(src_lens, src_max_len) # [B x slen]
574
+ src_pad_mask = src_pad_mask.unsqueeze(1) # [B x 1 x slen]
575
+ tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
576
+
577
+ with_align = kwargs.pop("with_align", False)
578
+ return_attn = with_align or self._copy
579
+ attn_aligns = []
580
+
581
+ for layer in self.transformer_layers:
582
+ dec_out, attn, attn_align = layer(
583
+ dec_out,
584
+ enc_out,
585
+ src_pad_mask,
586
+ tgt_pad_mask,
587
+ step=step,
588
+ with_align=with_align,
589
+ return_attn=return_attn,
590
+ )
591
+ if attn_align is not None:
592
+ attn_aligns.append(attn_align)
593
+
594
+ dec_out = self.layer_norm(dec_out)
595
+
596
+ attns = {"std": attn}
597
+ if self._copy:
598
+ attns["copy"] = attn
599
+ if with_align:
600
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
601
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
602
+
603
+ # TODO change the way attns is returned dict => list or tuple (onnx)
604
+ return dec_out, attns
605
+
606
+ def _init_cache(self, enc_out):
607
+ batch_size = enc_out.size(0)
608
+ depth = enc_out.size(-1)
609
+
610
+ for layer in self.transformer_layers:
611
+ # the first value is set to True at the start of decoding;
612
+ # the layer_cache then becomes active in the MultiHeadedAttention forward
613
+ layer.context_attn.layer_cache = (
614
+ True,
615
+ {
616
+ "keys": torch.tensor([], device=enc_out.device),
617
+ "values": torch.tensor([], device=enc_out.device),
618
+ },
619
+ )
620
+ if isinstance(layer.self_attn, AverageAttention):
621
+ layer.self_attn.layer_cache = True, {
622
+ "prev_g": torch.zeros(
623
+ (batch_size, 1, depth), device=enc_out.device
624
+ ).to(enc_out.dtype)
625
+ }
626
+ else:
627
+ layer.self_attn.layer_cache = (
628
+ True,
629
+ {
630
+ "keys": torch.tensor([], device=enc_out.device),
631
+ "values": torch.tensor([], device=enc_out.device),
632
+ },
633
+ )
634
+
635
+
636
+ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
637
+ """Transformer Decoder only layer block in GPT style.
638
+ Args:
639
+ See TransformerDecoderLayerBase
640
+ """
641
+
642
+ def _forward(
643
+ self, layer_in, tgt_pad_mask, step=None, future=False, return_attn=False
644
+ ):
645
+ """A naive forward pass for transformer decoder.
646
+
647
+ # T: could be 1 in the case of stepwise decoding or tgt_len
648
+
649
+ Args:
650
+ layer_in (FloatTensor): ``(batch_size, T, model_dim)``
651
+ tgt_pad_mask (bool): ``(batch_size, 1, T)``
652
+ layer_cache (dict or None): cached layer info for stepwise decoding
653
+ step (int or None): stepwise decoding counter
654
+ future (bool): If set True, do not apply future_mask.
655
+ return_attn (bool): If set True return attn
656
+
657
+ Returns:
658
+ (FloatTensor, FloatTensor):
659
+
660
+ * layer_out ``(batch_size, T, model_dim)``
661
+ * attns ``(batch_size, head, T, T)``
662
+
663
+ """
664
+ dec_mask = None
665
+
666
+ if layer_in.size(1) > 1:
667
+ # masking is necessary when sequence length is greater than one
668
+ dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
669
+ dec_mask = dec_mask.unsqueeze(1)
670
+ dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
671
+ # mask now are (batch x 1 x tlen x tlen)
672
+ # 1 = heads to be expanded in MHA
673
+
674
+ norm_layer_in = self.layer_norm_1(layer_in)
675
+
676
+ attn_output, attns = self._forward_self_attn(
677
+ norm_layer_in, dec_mask, step, return_attn=return_attn
678
+ )
679
+
680
+ if self.parallel_residual:
681
+ # feed_forward adds its own residual, so remove it and re-apply the residual with the un-normed input
682
+ if not self.shared_layer_norm:
683
+ norm_res_layer_in = self.layer_norm_res(layer_in)
684
+ ff_in = norm_res_layer_in
685
+ else:
686
+ ff_in = norm_layer_in
687
+ layer_out = (
688
+ self.feed_forward(ff_in) - ff_in + layer_in + self.dropout(attn_output)
689
+ )
690
+ else:
691
+ layer_out = self.dropout(attn_output) + layer_in
692
+ layer_out = self.feed_forward(layer_out)
693
+
694
+ return layer_out, attns
695
+
696
+
697
+ class TransformerLMDecoder(TransformerDecoderBase):
698
+ """The Transformer decoder from GPT-2
699
+ Args:
700
+ num_layers (int): number of decoder layers.
701
+ d_model (int): size of the model
702
+ heads (int): number of heads
703
+ d_ff (int): size of the inner FF layer
704
+ copy_attn (bool): if using a separate copy attention
705
+ self_attn_type (str): type of self-attention scaled-dot, average
706
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
707
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
708
+ embeddings (onmt.modules.Embeddings):
709
+ embeddings to use, should have positional encodings
710
+ max_relative_positions (int):
711
+ Max distance between inputs in relative positions representations
712
+ relative_positions_buckets (int):
713
+ Number of buckets when using Relative positions bias
714
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
715
+ add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
716
+ """
717
+
718
+ def __init__(
719
+ self,
720
+ num_layers,
721
+ d_model,
722
+ heads,
723
+ d_ff,
724
+ copy_attn,
725
+ self_attn_type,
726
+ dropout,
727
+ attention_dropout,
728
+ embeddings,
729
+ max_relative_positions,
730
+ relative_positions_buckets,
731
+ aan_useffn,
732
+ full_context_alignment=None,
733
+ alignment_layer=None,
734
+ alignment_heads=None,
735
+ pos_ffn_activation_fn=ActivationFunction.relu,
736
+ add_qkvbias=False,
737
+ num_kv=0,
738
+ add_ffnbias=True,
739
+ parallel_residual=False,
740
+ shared_layer_norm=False,
741
+ layer_norm="standard",
742
+ norm_eps=1e-6,
743
+ use_ckpting=[],
744
+ parallel_gpu=1,
745
+ ):
746
+ super(TransformerLMDecoder, self).__init__(
747
+ d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
748
+ )
749
+ self.transformer_layers = nn.ModuleList(
750
+ [
751
+ TransformerLMDecoderLayer(
752
+ d_model,
753
+ heads,
754
+ d_ff,
755
+ dropout,
756
+ attention_dropout,
757
+ self_attn_type=self_attn_type,
758
+ max_relative_positions=max_relative_positions,
759
+ relative_positions_buckets=relative_positions_buckets,
760
+ aan_useffn=aan_useffn,
761
+ full_context_alignment=None,
762
+ alignment_heads=None,
763
+ pos_ffn_activation_fn=pos_ffn_activation_fn,
764
+ add_qkvbias=add_qkvbias,
765
+ num_kv=num_kv,
766
+ add_ffnbias=add_ffnbias,
767
+ parallel_residual=parallel_residual,
768
+ shared_layer_norm=shared_layer_norm,
769
+ layer_norm=layer_norm,
770
+ norm_eps=norm_eps,
771
+ use_ckpting=use_ckpting,
772
+ parallel_gpu=parallel_gpu,
773
+ )
774
+ for i in range(num_layers)
775
+ ]
776
+ )
777
+
778
+ def init_state(self, src=None, enc_out=None, enc_final_hs=None):
779
+ super(TransformerLMDecoder, self).init_state(None, None, None)
780
+
781
+ def detach_state(self):
782
+ pass
783
+
784
+ def forward(self, tgt, enc_out=None, step=None, **kwargs):
785
+ """Decode, possibly stepwise."""
786
+ if step == 0:
787
+ self._init_cache(tgt)
788
+ elif step is None:
789
+ for layer in self.transformer_layers:
790
+ layer.self_attn.layer_cache = (
791
+ False,
792
+ {"keys": torch.tensor([]), "values": torch.tensor([])},
793
+ )
794
+
795
+ dec_out = self.embeddings(tgt, step=step)
796
+
797
+ assert dec_out.dim() == 3 # batch x len x embedding_dim
798
+
799
+ pad_idx = self.embeddings.word_padding_idx
800
+ tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
801
+
802
+ with_align = kwargs.pop("with_align", False)
803
+ return_attn = with_align or self._copy
804
+ assert not with_align, "TransformerLMDecoder does not support align"
805
+
806
+ for layer in self.transformer_layers:
807
+ dec_out, attn, _ = layer(
808
+ dec_out,
809
+ tgt_pad_mask,
810
+ step=step,
811
+ with_align=with_align,
812
+ return_attn=return_attn,
813
+ )
814
+
815
+ dec_out = self.layer_norm(dec_out)
816
+
817
+ attns = {"std": attn}
818
+ if self._copy:
819
+ attns["copy"] = attn
820
+
821
+ # TODO change the way attns is returned dict => list or tuple (onnx)
822
+ return dec_out, attns
823
+
824
+ def _init_cache(self, tgt=None):
825
+ for layer in self.transformer_layers:
826
+ if isinstance(layer.self_attn, AverageAttention):
827
+ raise NotImplementedError
828
+ else:
829
+ layer.self_attn.layer_cache = (
830
+ True,
831
+ {
832
+ "keys": torch.tensor([], device=tgt.device),
833
+ "values": torch.tensor([], device=tgt.device),
834
+ },
835
+ )
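For reference, the mask construction in _compute_dec_mask above can be illustrated with a small standalone sketch (plain torch, toy sizes): the padding mask is combined with an upper-triangular future mask so that position t attends only to non-padded positions up to t. The `|` below is equivalent to the torch.gt(pad + future, 0) used in the layer:

import torch

batch_size, tgt_len = 2, 5
lengths = torch.tensor([5, 3])  # the second sequence has two padded positions

# padding mask: True where the target position is padding -> (B, 1, T)
pad_mask = (torch.arange(tgt_len)[None, :] >= lengths[:, None]).unsqueeze(1)

# future mask: True strictly above the diagonal -> (1, T, T)
future_mask = (
    torch.ones(tgt_len, tgt_len, dtype=torch.uint8)
    .triu_(1)
    .view(1, tgt_len, tgt_len)
    .bool()
)

# combined decoder mask: True = masked out -> (B, T, T)
dec_mask = pad_mask | future_mask
print(dec_mask[1].int())

During stepwise decoding (layer_in.size(1) == 1) dec_mask stays None, matching the `if layer_in.size(1) > 1:` guard in the layers above.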
onmt/encoders/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """Module defining encoders."""
2
+ import os
3
+ import importlib
4
+ from onmt.encoders.encoder import EncoderBase
5
+ from onmt.encoders.transformer import TransformerEncoder
6
+ from onmt.encoders.ggnn_encoder import GGNNEncoder
7
+ from onmt.encoders.rnn_encoder import RNNEncoder
8
+ from onmt.encoders.cnn_encoder import CNNEncoder
9
+ from onmt.encoders.mean_encoder import MeanEncoder
10
+
11
+
12
+ str2enc = {
13
+ "ggnn": GGNNEncoder,
14
+ "rnn": RNNEncoder,
15
+ "brnn": RNNEncoder,
16
+ "cnn": CNNEncoder,
17
+ "transformer": TransformerEncoder,
18
+ "mean": MeanEncoder,
19
+ }
20
+
21
+ __all__ = [
22
+ "EncoderBase",
23
+ "TransformerEncoder",
24
+ "GGNNEncoder",
25
+ "RNNEncoder",
26
+ "CNNEncoder",
27
+ "MeanEncoder",
28
+ "str2enc",
29
+ ]
30
+
31
+
32
+ def get_encoders_cls(encoder_names):
33
+ """Return valid encoder class indicated in `encoder_names`."""
34
+ encoders_cls = {}
35
+ for name in encoder_names:
36
+ if name not in str2enc:
37
+ raise ValueError("%s encoder not supported!" % name)
38
+ encoders_cls[name] = str2enc[name]
39
+ return encoders_cls
40
+
41
+
42
+ def register_encoder(name):
43
+ """Encoder register that can be used to add new encoder class."""
44
+
45
+ def register_encoder_cls(cls):
46
+ if name in str2enc:
47
+ raise ValueError("Cannot register duplicate encoder ({})".format(name))
48
+ if not issubclass(cls, EncoderBase):
49
+ raise ValueError(f"encoder ({name}: {cls.__name__}) must extend EncoderBase")
50
+ str2enc[name] = cls
51
+ __all__.append(cls.__name__) # added to be complete
52
+ return cls
53
+
54
+ return register_encoder_cls
55
+
56
+
57
+ # Auto import python files in this directory
58
+ encoder_dir = os.path.dirname(__file__)
59
+ for file in os.listdir(encoder_dir):
60
+ path = os.path.join(encoder_dir, file)
61
+ if (
62
+ not file.startswith("_")
63
+ and not file.startswith(".")
64
+ and (file.endswith(".py") or os.path.isdir(path))
65
+ ):
66
+ file_name = file[: file.find(".py")] if file.endswith(".py") else file
67
+ module = importlib.import_module("onmt.encoders." + file_name)
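For reference, a hypothetical use of the register_encoder decorator defined above. The "identity" name and IdentityEncoder class are illustrative only (they are not part of this repository), and the sketch assumes the onmt package from this commit is importable:

from onmt.encoders import EncoderBase, register_encoder


@register_encoder("identity")
class IdentityEncoder(EncoderBase):
    """Toy encoder that returns the embedded source unchanged."""

    def __init__(self, embeddings):
        super().__init__()
        self.embeddings = embeddings

    @classmethod
    def from_opt(cls, opt, embeddings):
        # mirror the from_opt convention used by the built-in encoders
        return cls(embeddings)

    def forward(self, src, src_len=None):
        emb = self.embeddings(src)
        # follow the (enc_out, enc_final_hs, src_len) contract of EncoderBase
        return emb, None, src_len

Once registered, the class is reachable through str2enc["identity"] and therefore through get_encoders_cls(["identity"]).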
onmt/encoders/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.13 kB).