|
|
|
"""Get vocabulary coutings from transformed corpora samples.""" |
|
import os |
|
import copy |
|
import multiprocessing as mp |
|
import pyonmttok |
|
from functools import partial |
|
from onmt.utils.logging import init_logger, logger |
|
from onmt.utils.misc import set_random_seed, check_path |
|
from onmt.utils.parse import ArgumentParser |
|
from onmt.opts import dynamic_prepare_opts |
|
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora |
|
from onmt.inputters.text_utils import process, append_features_to_text |
|
from onmt.transforms import make_transforms, get_transforms_cls |
|
from onmt.constants import CorpusName, CorpusTask |
|
from collections import Counter |
|
|
|
|
|
MAXBUCKETSIZE = 256000 |
|
|
|
|
|
def write_files_from_queues(sample_path, queues):
    """
    Standalone process that reads data from
    queues in order and writes it to sample files.
    """
    os.makedirs(sample_path, exist_ok=True)
    for c_name in queues.keys():
        dest_base = os.path.join(
            sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE)
        )
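        # Drain the per-worker queues in round-robin order so that lines come
        # out in their original corpus order; "blank" marks an example that was
        # filtered out and "break" means a worker has finished this corpus.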
|
        with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
            dest_base + ".tgt", "w", encoding="utf-8"
        ) as f_tgt:
            while True:
                _next = False
                for q in queues[c_name]:
                    item = q.get()
                    if item == "blank":
                        continue
                    if item == "break":
                        _next = True
                        break
                    _, src_line, tgt_line = item
                    f_src.write(src_line + "\n")
                    f_tgt.write(tgt_line + "\n")
                if _next:
                    break
|
|
def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
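    # With stride == opts.num_threads and a distinct offset per worker, each
    # worker only counts a disjoint shard of every corpus.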
|
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            maybe_example = process(CorpusTask.TRAIN, [item])
            if maybe_example is not None:
                maybe_example = maybe_example[0]
            else:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (
                maybe_example["src"]["src"],
                maybe_example["tgt"]["tgt"],
            )
            sub_counter_src.update(src_line.split(" "))
            sub_counter_tgt.update(tgt_line.split(" "))

            if "feats" in maybe_example["src"]:
                src_feats_lines = maybe_example["src"]["feats"]
                for k in range(opts.n_src_feats):
                    sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
            else:
                src_feats_lines = []

            if opts.dump_samples:
                src_pretty_line = append_features_to_text(src_line, src_feats_lines)
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_pretty_line, tgt_line)
                )
            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
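# Each worker returns its partial counts, e.g. something like
#   (Counter({"the": 1203, ...}), Counter({"die": 987, ...}), [Counter(), ...]),
# which build_vocab() below merges into the global counters.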
|
|
|
|
|
def init_pool(queues):
    """Add the queues as an attribute of the pooled function."""
    build_sub_vocab.queues = queues
|
|
def build_vocab(opts, transforms, n_sample=3):
    """Build vocabulary from data."""

    if n_sample == -1:
        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
    elif n_sample > 0:
        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
    else:
        raise ValueError(f"n_sample should be > 0 or == -1, got {n_sample}.")

    if opts.dump_samples:
        logger.info(
            "The samples on which the vocab is built will be "
            "dumped to disk. It may slow down the process."
        )
    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]

    queues = {
        c_name: [
            mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
        ]
        for c_name in corpora.keys()
    }
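    # One mp.Queue per (corpus, worker) pair: build_sub_vocab workers fill them
    # and, when dump_samples is set, write_files_from_queues drains them in a
    # separate process.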
|
    sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
    if opts.dump_samples:
        write_process = mp.Process(
            target=write_files_from_queues, args=(sample_path, queues), daemon=True
        )
        write_process.start()
    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
        func = partial(
            build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
        )
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
            func, range(0, opts.num_threads)
        ):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            for i in range(opts.n_src_feats):
                counter_src_feats[i].update(sub_counter_src_feats[i])
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats
|
|
def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
    def _mp_ingest(data):
        func = partial(process, CorpusName.TRAIN)
        chunk = len(data) // opts.num_threads
        with mp.Pool(opts.num_threads) as pool:
            buckets = pool.map(
                func,
                [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
            )
        for bucket in buckets:
            for ex in bucket:
                if ex is not None:
                    src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
                    learner.ingest(src_line)
                    learner.ingest(tgt_line)

    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
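    # Buffer raw examples and hand them to _mp_ingest in batches of at most
    # MAXBUCKETSIZE; each batch is split across opts.num_threads workers and the
    # resulting src/tgt lines are fed to the subword learner.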
|
    to_ingest = []
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            if n_sample >= 0 and i >= n_sample:
                break
            if len(to_ingest) >= MAXBUCKETSIZE:
                _mp_ingest(to_ingest)
                to_ingest = []
            to_ingest.append(item)
    _mp_ingest(to_ingest)
|
|
def make_learner(tokenization_type, symbols):
    """Return a pyonmttok subword learner of the requested type.

    `symbols` controls the size of the learned subword inventory (BPE merge
    operations or SentencePiece vocabulary size).
    """
    if tokenization_type == "bpe":
        learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
    elif tokenization_type == "sentencepiece":
        learner = pyonmttok.SentencePieceLearner(
            vocab_size=symbols, character_coverage=0.98
        )
    else:
        raise ValueError(f"Unsupported tokenization type: {tokenization_type}")
    return learner
|
|
def build_vocab_main(opts):
    """Apply transforms to samples of the specified data and build a vocab from them.

    Transforms that require an existing vocab are disabled here.
    The built vocab is saved in plain-text format as follows and can be passed
    as `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert (
        opts.n_sample == -1 or opts.n_sample > 1
    ), f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
|
    if opts.learn_subwords:
        logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
        learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
        if opts.src_subword_model is not None:
            tok_path = opts.src_subword_model
        else:
            data_dir = os.path.split(opts.save_data)[0]
            if not os.path.exists(data_dir):
                os.makedirs(data_dir)
            tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
        # Temporarily disable subword tokenization and joiner annotation so the
        # learner is fed text transformed by everything else; the original opts
        # are restored once the subword model has been learned.
        save_opts = copy.deepcopy(opts)
        opts.src_subword_type = "none"
        opts.tgt_subword_type = "none"
        opts.src_onmttok_kwargs["joiner_annotate"] = False
        opts.tgt_onmttok_kwargs["joiner_annotate"] = False
        transforms = make_transforms(opts, transforms_cls, None)
        ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
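        # pyonmttok learners expose ingest()/learn(): everything ingested above
        # is used by learn() to train the subword model and write it to tok_path.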
|
logger.info(f"Learning {tok_path} model, patience") |
|
learner.learn(tok_path) |
|
opts = save_opts |
|
|
|
transforms = make_transforms(opts, transforms_cls, None) |
|
|
|
logger.info(f"Counter vocab from {opts.n_sample} samples.") |
|
src_counter, tgt_counter, src_feats_counter = build_vocab( |
|
opts, transforms, n_sample=opts.n_sample |
|
) |
|
|
|
logger.info(f"Counters src: {len(src_counter)}") |
|
logger.info(f"Counters tgt: {len(tgt_counter)}") |
|
for i, feat_counter in enumerate(src_feats_counter): |
|
logger.info(f"Counters src feat_{i}: {len(feat_counter)}") |
|
|
|
    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share: {len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for i, c in enumerate(src_feats_counter):
        save_counter(c, f"{opts.src_vocab}_feat{i}")
|
|
def _get_parser():
    parser = ArgumentParser(description="build_vocab.py")
    dynamic_prepare_opts(parser, build_vocab_only=True)
    return parser
|
|
def main():
    parser = _get_parser()
    opts, unknown = parser.parse_known_args()
    build_vocab_main(opts)
|
|
if __name__ == "__main__":
    main()
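# Typical use: this module backs OpenNMT-py's `onmt_build_vocab` entry point,
# e.g. `onmt_build_vocab -config my_config.yaml -n_sample 10000`
# (flag names come from dynamic_prepare_opts; exact options may vary by version).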
|
|