#!/usr/bin/env python
"""Training on a single process."""
import torch
import sys
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.constants import CorpusTask
from onmt.transforms import (
    make_transforms,
    save_transforms,
    get_specials,
    get_transforms_cls,
)
from onmt.inputters import build_vocab, IterOnDevice
from onmt.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.text_corpus import save_transformed_sample
from onmt.model_builder import build_model
from onmt.models.model_saver import load_checkpoint
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.modules.embeddings import prepare_pretrained_embeddings


def prepare_transforms_vocabs(opt, transforms_cls):
"""Prepare or dump transforms before training."""
# if transform + options set in 'valid' we need to copy in main
# transform / options for scoring considered as inference
validset_transforms = opt.data.get("valid", {}).get("transforms", None)
if validset_transforms:
opt.transforms = validset_transforms
if opt.data.get("valid", {}).get("tgt_prefix", None):
opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
if opt.data.get("valid", {}).get("src_prefix", None):
opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
if opt.data.get("valid", {}).get("tgt_suffix", None):
opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
if opt.data.get("valid", {}).get("src_suffix", None):
opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
    specials = get_specials(opt, transforms_cls)

    vocabs = build_vocab(opt, specials)

    # Maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, vocabs)

    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, vocabs)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning(
            "`-n_sample` != 0: training will not be started. "
            f"Stopping after saving {opt.n_sample} samples per corpus."
        )
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restarting training.")
        sys.exit()
    logger.info(
        "The first 10 tokens of the src vocab are: "
        f"{vocabs_to_dict(vocabs)['src'][0:10]}"
    )
    logger.info(f"The decoder start token is: {opt.decoder_start_token}")
    return vocabs


def _init_train(opt):
    """Common initialization for all training processes.

    We need to build or rebuild the vocab in 3 cases:
    - training from scratch (train_from is false)
    - resume training but transforms have changed
    - resume training but vocab file has been modified
    """
    ArgumentParser.validate_prepare_opts(opt)
    transforms_cls = get_transforms_cls(opt._all_transform)
    if opt.train_from:
        # Load checkpoint if we resume from a previous training.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        vocabs = dict_to_vocabs(checkpoint["vocab"])
        if (
            hasattr(checkpoint["opt"], "_all_transform")
            and len(
                opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform
                )
            )
            != 0
        ):
            _msg = "configured transforms differ from the checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform
            )
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform
            )
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
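            # Transforms changed w.r.t. the checkpoint: rebuild the vocab so
            # any transform-specific special tokens are collected again.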
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
    else:
        checkpoint = None
        vocabs = prepare_transforms_vocabs(opt, transforms_cls)

    return checkpoint, vocabs, transforms_cls


def configure_process(opt, device_id):
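    """Bind this process to its CUDA device (if any) and seed the RNGs."""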
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build the model; may load from `checkpoint` if any."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        if opt.override_opts:
            logger.info("Option `override_opts` set to true - use with care")
            args = list(opt.__dict__.keys())
            model_args = list(model_opt.__dict__.keys())
            for arg in args:
                if arg in model_args and getattr(opt, arg) != getattr(model_opt, arg):
                    logger.info(
                        "Option: %s, value: %s overriding model value: %s"
                        % (arg, getattr(opt, arg), getattr(model_opt, arg))
                    )
            model_opt = opt
        else:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(
            model_opt, "tensorboard_log_dir_dated"
        ):
            # Ensure tensorboard output is written in the directory
            # of previous checkpoints.
            opt.tensorboard_log_dir_dated = (
                model_opt.tensorboard_log_dir_dated
            )  # noqa: E501
        # Override the checkpoint's update_vocab flag as it defaults to false
        model_opt.update_vocab = opt.update_vocab
        # Override the checkpoint's freezing settings as they default to false
        model_opt.freeze_encoder = opt.freeze_encoder
        model_opt.freeze_decoder = opt.freeze_decoder
    else:
        model_opt = opt
    return model_opt


def main(opt, device_id):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    checkpoint, vocabs, transforms_cls = _init_train(opt)
    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, vocabs, checkpoint, device_id)
    model.count_parameters(log=logger.info)
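    # Tally parameter counts per dtype, split into trainable vs. frozen,
    # so e.g. quantized (int8/uint8) or frozen parts show up in the logs.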
    trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    non_trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    for n, p in model.named_parameters():
        if p.requires_grad:
            trainable[str(p.dtype)] += p.numel()
        else:
            non_trainable[str(p.dtype)] += p.numel()
    logger.info("Trainable parameters = %s" % str(trainable))
    logger.info("Non-trainable parameters = %s" % str(non_trainable))
    logger.info(" * src vocab size = %d" % len(vocabs["src"]))
    logger.info(" * tgt vocab size = %d" % len(vocabs["tgt"]))
    if "src_feats" in vocabs:
        for i, feat_vocab in enumerate(vocabs["src_feats"]):
            logger.info(f" * src_feat {i} vocab size = {len(feat_vocab)}")

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    del checkpoint

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)

    trainer = build_trainer(
        opt, device_id, model, vocabs, optim, model_saver=model_saver
    )
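
    # In data_parallel mode each GPU reads a disjoint shard of the corpus:
    # rank `device_id` starts at example `offset` and then takes every
    # `stride`-th example, so the ranks never consume the same lines.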
    offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
    stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1

    _train_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.TRAIN,
        copy=opt.copy_attn,
        stride=stride,
        offset=offset,
    )
    train_iter = IterOnDevice(_train_iter, device_id)

    valid_iter = build_dynamic_dataset_iter(
        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
    )
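    # `build_dynamic_dataset_iter` returns None when no validation set is
    # configured, in which case validation is simply skipped.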
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU: this could be very slow")
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
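
    # With train_steps == 0 the trainer runs until the training iterator is
    # exhausted, i.e. a single pass over the dataset.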
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps,
    )

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
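

# A minimal usage sketch (hypothetical values; in practice the `onmt.bin.train`
# entry point validates the options and invokes this function, one process per
# GPU rank):
#
#     opt = ...                 # fully built/validated training options
#     main(opt, device_id=0)    # use device_id=-1 to train on CPU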