"""Training on a single process."""
import sys

import torch

from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.constants import CorpusTask
from onmt.transforms import (
    make_transforms,
    save_transforms,
    get_specials,
    get_transforms_cls,
)
from onmt.inputters import build_vocab, IterOnDevice
from onmt.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.text_corpus import save_transformed_sample
from onmt.model_builder import build_model
from onmt.models.model_saver import load_checkpoint
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.modules.embeddings import prepare_pretrained_embeddings


def prepare_transforms_vocabs(opt, transforms_cls):
    """Prepare or dump transforms before training."""
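    # If the 'valid' corpus defines its own transforms / prefixes / suffixes,
    # copy them into the top-level options so that validation scoring (run as
    # inference) picks them up. For illustration only, the usual YAML layout
    # looks roughly like this (an assumption, not taken from this file):
    #
    #   data:
    #     corpus_1:
    #       path_src: ...
    #       path_tgt: ...
    #       transforms: [sentencepiece, filtertoolong]
    #     valid:
    #       path_src: ...
    #       path_tgt: ...
    #       transforms: [sentencepiece]
    #       src_prefix: "<to_xx>"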
    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
    if validset_transforms:
        opt.transforms = validset_transforms
        if opt.data.get("valid", {}).get("tgt_prefix", None):
            opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
        if opt.data.get("valid", {}).get("src_prefix", None):
            opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
        if opt.data.get("valid", {}).get("tgt_suffix", None):
            opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
        if opt.data.get("valid", {}).get("src_suffix", None):
            opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
    specials = get_specials(opt, transforms_cls)

    vocabs = build_vocab(opt, specials)

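    # Handle any pretrained embeddings options against the freshly built vocabs.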
    prepare_pretrained_embeddings(opt, vocabs)

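    # Optionally dump the built transforms and/or save a few transformed samples
    # for inspection; saving samples stops before any training happens.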
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, vocabs)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning(
            "`-n_sample` != 0: Training will not be started. "
            f"Stop after saving {opt.n_sample} samples/corpus."
        )
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restarting training.")
        sys.exit()
    logger.info(
        "The first 10 tokens of the vocabs are: "
        f"{vocabs_to_dict(vocabs)['src'][0:10]}"
    )
    logger.info(f"The decoder start token is: {opt.decoder_start_token}")
    return vocabs


def _init_train(opt):
    """Common initialization for every training process.

    We need to build or rebuild the vocab in 3 cases:
    - training from scratch (train_from is false)
    - resuming training, but the transforms have changed
    - resuming training, but the vocab file has been modified
    """
    ArgumentParser.validate_prepare_opts(opt)
    transforms_cls = get_transforms_cls(opt._all_transform)
    if opt.train_from:
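        # Resuming: reuse the checkpoint's vocab, then check whether the
        # currently configured transforms match the ones the checkpoint was
        # trained with.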
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        vocabs = dict_to_vocabs(checkpoint["vocab"])
        if (
            hasattr(checkpoint["opt"], "_all_transform")
            and len(
                opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform
                )
            )
            != 0
        ):
            _msg = "configured transforms differ from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform
            )
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform
            )
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
    else:
        checkpoint = None
        vocabs = prepare_transforms_vocabs(opt, transforms_cls)

    return checkpoint, vocabs, transforms_cls


def configure_process(opt, device_id):
    """Bind this process to ``device_id`` (if it is a GPU) and seed the RNGs."""
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build the model, possibly loading it from `checkpoint`."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        if opt.override_opts:
            logger.info("override_opts is set to true - use with care")
            args = list(opt.__dict__.keys())
            model_args = list(model_opt.__dict__.keys())
            for arg in args:
                if arg in model_args and getattr(opt, arg) != getattr(model_opt, arg):
                    logger.info(
                        "Option: %s, value: %s overriding model value: %s"
                        % (arg, getattr(opt, arg), getattr(model_opt, arg))
                    )
            model_opt = opt
        else:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(
            model_opt, "tensorboard_log_dir_dated"
        ):
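            # Keep writing TensorBoard logs to the dated log directory of the
            # run we are resuming from.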
            opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
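        # These switches follow the current command line rather than whatever
        # was stored in the checkpoint.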
        model_opt.update_vocab = opt.update_vocab
        model_opt.freeze_encoder = opt.freeze_encoder
        model_opt.freeze_decoder = opt.freeze_decoder
    else:
        model_opt = opt
    return model_opt


def main(opt, device_id):
    """Start training on `device_id`."""
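    # ``opt`` is expected to have been validated and updated by the caller
    # before we get here.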
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    checkpoint, vocabs, transforms_cls = _init_train(opt)
    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

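    # Build the model, restoring its parameters from the checkpoint when resuming.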
    model = build_model(model_opt, opt, vocabs, checkpoint, device_id)

    model.count_parameters(log=logger.info)
    trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    non_trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    for n, p in model.named_parameters():
        if p.requires_grad:
            trainable[str(p.dtype)] += p.numel()
        else:
            non_trainable[str(p.dtype)] += p.numel()
    logger.info("Trainable parameters = %s" % str(trainable))
    logger.info("Non trainable parameters = %s" % str(non_trainable))
    logger.info(" * src vocab size = %d" % len(vocabs["src"]))
    logger.info(" * tgt vocab size = %d" % len(vocabs["tgt"]))
    if "src_feats" in vocabs:
        for i, feat_vocab in enumerate(vocabs["src_feats"]):
            logger.info(f"* src_feat {i} vocab size = {len(feat_vocab)}")

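    # Build the optimizer, restoring its state from the checkpoint if present.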
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # The raw checkpoint is no longer needed; drop the reference so it can be
    # garbage collected.
    del checkpoint

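    # Build the model saver used to write periodic checkpoints.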
    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)

    trainer = build_trainer(
        opt, device_id, model, vocabs, optim, model_saver=model_saver
    )

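    # In "data_parallel" mode each GPU reads a disjoint shard of the corpus:
    # rank ``device_id`` takes every ``stride``-th example starting at ``offset``.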
    offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
    stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1

    _train_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.TRAIN,
        copy=opt.copy_attn,
        stride=stride,
        offset=offset,
    )
    train_iter = IterOnDevice(_train_iter, device_id)

    valid_iter = build_dynamic_dataset_iter(
        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
    )

    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU, could be very slow")
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps,
    )

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
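# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module): training is normally
# launched through the ``onmt_train`` entry point, which parses the options
# and then calls ``main``. The helper ``train_opts`` below reflects the usual
# OpenNMT-py layout and is an assumption, not something defined in this file.
#
#     from onmt.opts import train_opts
#     from onmt.utils.parse import ArgumentParser
#     from onmt.train_single import main
#
#     parser = ArgumentParser(description="train.py")
#     train_opts(parser)
#     opt, _ = parser.parse_known_args(["-config", "my_config.yaml"])
#     main(opt, device_id=0)  # a negative device_id runs on CPU
# ---------------------------------------------------------------------------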