import time
import math
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from tencentpretrain.model_loader import load_model
from tencentpretrain.model_saver import save_model
from tencentpretrain.model_builder import build_model
from tencentpretrain.utils.logging import init_logger
from tencentpretrain.utils.optimizers import *
from tencentpretrain.utils import *
from tencentpretrain.utils.seed import set_seed
from tqdm import tqdm
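# Pretraining loop for TencentPretrain: train_and_validate() builds the model,
# loads or initializes its parameters, and dispatches to worker() in CPU,
# single-GPU, multiprocessing-distributed, or DeepSpeed mode.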
def train_and_validate(args):
    set_seed(args.seed)

    # Load vocabulary.
    if args.data_processor == "mt":
        args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
        args.tgt_vocab = args.tgt_tokenizer.vocab

    args.tokenizer = str2tokenizer[args.tokenizer](args)
    args.vocab = args.tokenizer.vocab

    # Build model.
    model_for_training = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model_for_training = load_model(model_for_training, args.pretrained_model_path)
    else:
        # Initialize with normal distribution.
        if args.deep_init:
            scaled_factor = 1 / math.sqrt(2.0 * args.layers_num)
            for n, p in list(model_for_training.named_parameters()):
                if "gamma" not in n and "beta" not in n:
                    if "linear_2.weight" in n or "final_linear.weight" in n:
                        p.data.normal_(0, 0.02 * scaled_factor)
                    elif "linear_2.bias" in n or "final_linear.bias" in n:
                        p.data.zero_()
                    else:
                        p.data.normal_(0, 0.02)
        else:
            for n, p in list(model_for_training.named_parameters()):
                if "gamma" not in n and "beta" not in n:
                    p.data.normal_(0, 0.02)

    if args.vqgan_model_path is not None:
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        model_for_dataloader = build_vqgan_model(args)
    else:
        model_for_dataloader = None

    if args.deepspeed:
        worker(args.local_rank, None, args, model_for_training, model_for_dataloader)
    elif args.dist_train:
        # Multiprocessing distributed mode.
        mp.spawn(worker, nprocs=args.ranks_num, args=(args.gpu_ranks, args, model_for_training, model_for_dataloader), daemon=False)
    elif args.single_gpu:
        # Single GPU mode.
        worker(args.gpu_id, None, args, model_for_training, model_for_dataloader)
    else:
        # CPU mode.
        worker(None, None, args, model_for_training, model_for_dataloader)

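# Base class shared by all pretraining objectives. Subclasses implement
# forward_propagation() to compute the loss on one batch and
# report_and_reset_stats() to log running statistics every report_steps steps.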
class Trainer(object):
    def __init__(self, args):
        self.current_step = 1
        self.total_steps = args.total_steps
        self.accumulation_steps = args.accumulation_steps
        self.report_steps = args.report_steps
        self.save_checkpoint_steps = args.save_checkpoint_steps

        self.output_model_path = args.output_model_path

        self.start_time = time.time()
        self.total_loss = 0.0
        self.best_loss = float("inf")

        self.dist_train = args.dist_train
        self.batch_size = args.batch_size
        self.world_size = args.world_size
        self.logger = args.logger

    def forward_propagation(self, batch, model):
        raise NotImplementedError

    def report_and_reset_stats(self):
        raise NotImplementedError

    def train(self, args, gpu_id, rank, loader, model, optimizer, scheduler):
        model.train()
        loader_iter = iter(loader)
        while True:
            # for step in tqdm(range(self.current_step, self.total_steps + 1)):
            if self.current_step == self.total_steps + 1:
                break
            batch = list(next(loader_iter))
            self.seq_length = batch[0].size(1)
            if gpu_id is not None:
                # Move all tensors in the batch to this worker's GPU.
                for i in range(len(batch)):
                    if torch.is_tensor(batch[i]):
                        batch[i] = batch[i].cuda(gpu_id)

            loss = self.forward_propagation(batch, model)

            if args.deepspeed:
                model.backward(loss)
            else:
                if args.fp16:
                    with args.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            # Update parameters once every accumulation_steps micro-batches.
            if self.current_step % self.accumulation_steps == 0:
                if args.deepspeed:
                    model.step()
                else:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

            if self.current_step % self.report_steps == 0 and \
                    (not self.dist_train or (self.dist_train and rank == 0)):
                self.report_and_reset_stats()
                self.start_time = time.time()

            if args.deepspeed:
                if self.current_step % self.save_checkpoint_steps == 0:
                    model.save_checkpoint(self.output_model_path, str(self.current_step))
                    if loss.item() < self.best_loss:
                        self.best_loss = loss.item()
                        # model.save_checkpoint(self.output_model_path, "-best")
            else:
                if self.current_step % self.save_checkpoint_steps == 0 and \
                        (not self.dist_train or (self.dist_train and rank == 0)):
                    save_model(model, self.output_model_path + "-" + str(self.current_step))
                    if loss.item() < self.best_loss:
                        self.best_loss = loss.item()
                        # print("save best model! loss:" + str(self.best_loss))
                        # save_model(model, self.output_model_path + "-best")

            self.current_step += 1

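# Objective-specific trainers. Each one accumulates the statistics its loss
# returns (e.g. masked-token accuracy for MLM, sentence-pair accuracy for BERT)
# and divides the loss by accumulation_steps before returning it for backward().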
class MlmTrainer(Trainer):
    def __init__(self, args):
        super(MlmTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt, seg = batch
        loss_info = model(src, tgt, seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_denominator))

        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_denominator = 0.0

class BertTrainer(Trainer):
    def __init__(self, args):
        super(BertTrainer, self).__init__(args)
        self.total_loss_sp = 0.0
        self.total_correct_sp = 0.0
        self.total_instances = 0.0

        self.total_loss_mlm = 0.0
        self.total_correct_mlm = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_mlm, tgt_sp, seg = batch
        tgt = {"mlm": tgt_mlm, "sp": tgt_sp}
        loss_info = model(src, tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_sp, correct_sp = loss_info["sp"]
        loss = loss_mlm + loss_sp
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_sp += loss_sp.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_sp += correct_sp.item()
        self.total_denominator += denominator.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_sp: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_sp: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_sp / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_sp / self.total_instances))

        self.total_loss, self.total_loss_mlm, self.total_loss_sp = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_sp, self.total_instances = 0.0, 0.0

class AlbertTrainer(BertTrainer):
    pass


class LmTrainer(MlmTrainer):
    pass

class BilmTrainer(Trainer):
    def __init__(self, args):
        super(BilmTrainer, self).__init__(args)
        self.total_loss_forward, self.total_loss_backward = 0.0, 0.0
        self.total_correct_forward, self.total_correct_backward = 0.0, 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_forward, tgt_backward, seg = batch
        loss_info = model(src, (tgt_forward, tgt_backward), seg)
        loss_forward, loss_backward, correct_forward, correct_backward, denominator = loss_info
        loss = loss_forward + loss_backward
        self.total_loss += loss.item()
        self.total_loss_forward += loss_forward.item()
        self.total_loss_backward += loss_backward.item()
        self.total_correct_forward += correct_forward.item()
        self.total_correct_backward += correct_backward.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_forward {:3.3f}"
                         "| loss_backward {:3.3f}"
                         "| acc_forward: {:3.3f}"
                         "| acc_backward: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_forward / self.report_steps,
                             self.total_loss_backward / self.report_steps,
                             self.total_correct_forward / self.total_denominator,
                             self.total_correct_backward / self.total_denominator))

        self.total_loss, self.total_loss_forward, self.total_loss_backward = 0.0, 0.0, 0.0
        self.total_correct_forward, self.total_correct_backward, self.total_denominator = 0.0, 0.0, 0.0

class ClsTrainer(Trainer):
    def __init__(self, args):
        super(ClsTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_instances = 0.0

    def forward_propagation(self, batch, model):
        src, tgt, seg = batch
        loss_info = model(src, tgt, seg)
        loss, correct = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_instances))

        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_instances = 0.0

class MtTrainer(Trainer):
    def __init__(self, args):
        super(MtTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_out, seg, tgt_in, tgt_seg = batch
        loss_info = model(src, tgt_out, seg, tgt_in, tgt_seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_denominator))

        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_denominator = 0.0

class ClsMlmTrainer(Trainer):
    def __init__(self, args):
        super(ClsMlmTrainer, self).__init__(args)
        self.total_loss_cls = 0.0
        self.total_correct_cls = 0.0
        self.total_instances = 0.0

        self.total_loss_mlm = 0.0
        self.total_correct_mlm = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_mlm, tgt_cls, seg = batch
        tgt = {"mlm": tgt_mlm, "cls": tgt_cls}
        loss_info = model(src, tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_cls, correct_cls = loss_info["cls"]
        loss = loss_mlm + loss_cls
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_cls += loss_cls.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_cls += correct_cls.item()
        self.total_denominator += denominator.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_cls: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_cls: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_cls / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_cls / self.total_instances))

        self.total_loss, self.total_loss_mlm, self.total_loss_cls = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_cls, self.total_instances = 0.0, 0.0

class T5Trainer(MtTrainer):
    pass


class GsgTrainer(MtTrainer):
    pass


class BartTrainer(MtTrainer):
    pass


class PrefixlmTrainer(MlmTrainer):
    pass

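# Vision and multimodal trainers reuse the bookkeeping of the text trainers;
# VitTrainer only changes the reported throughput unit from tokens/s to patches/s.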
class VitTrainer(ClsTrainer):
    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} patches/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_instances))

        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_instances = 0.0

class ViltTrainer(BertTrainer):
    def forward_propagation(self, batch, model):
        src_text, src_image, tgt_mlm, tgt_match, seg = batch
        tgt = {"mlm": tgt_mlm, "sp": tgt_match}
        loss_info = model((src_text, src_image), tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_match, correct_match = loss_info["sp"]
        loss = loss_mlm + loss_match
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_sp += loss_match.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_sp += correct_match.item()
        self.total_denominator += denominator.item()
        self.total_instances += src_text.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size

        # MLM accuracy is per masked token (total_denominator); match accuracy
        # is per image-text pair (total_instances).
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_match: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_match: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_sp / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_sp / self.total_instances))

        self.total_loss, self.total_loss_mlm, self.total_loss_sp = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_sp, self.total_instances = 0.0, 0.0

class ClipTrainer(ClsTrainer):
    def forward_propagation(self, batch, model):
        src_text, src_img, seg_text, seg_img = batch
        loss_info = model((src_text, src_img), None, (seg_text, seg_img))
        loss, correct = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_instances += src_text.size(0)
        loss = loss / self.accumulation_steps
        return loss

class S2tTrainer(MtTrainer):
    pass


class BeitTrainer(MlmTrainer):
    def forward_propagation(self, batch, model):
        src, tgt, seg, mask = batch
        loss_info = model((src, mask), tgt, seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss


class DalleTrainer(LmTrainer):
    pass

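# Maps the data_processor option to the trainer that implements its objective;
# worker() uses this table to instantiate the right trainer.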
str2trainer = {"bert": BertTrainer, "mlm": MlmTrainer, "lm": LmTrainer,
               "albert": AlbertTrainer, "bilm": BilmTrainer, "cls": ClsTrainer,
               "mt": MtTrainer, "t5": T5Trainer, "gsg": GsgTrainer,
               "bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer,
               "vit": VitTrainer, "vilt": ViltTrainer, "clip": ClipTrainer, "s2t": S2tTrainer,
               "beit": BeitTrainer, "dalle": DalleTrainer}

def worker(proc_id, gpu_ranks, args, model_for_training, model_for_dataloader=None):
    """
    Args:
        proc_id: The id of GPU for single GPU mode;
                 the id of process (and GPU) for multiprocessing distributed mode.
        gpu_ranks: List of ranks of each process.
    """
    set_seed(args.seed)

    # Get logger.
    args.logger = init_logger(args)

    if args.deepspeed:
        import deepspeed
        deepspeed.init_distributed(dist_backend=args.backend)
        rank = dist.get_rank()
        gpu_id = proc_id
    elif args.dist_train:
        rank = gpu_ranks[proc_id]
        gpu_id = proc_id
    elif args.single_gpu:
        rank = None
        gpu_id = proc_id
    else:
        rank = None
        gpu_id = None

    # Build optimizer.
    param_optimizer = list(model_for_training.named_parameters())
    no_decay = ["bias", "gamma", "beta"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]

    if args.optimizer in ["adamw"]:
        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    else:
        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, scale_parameter=False, relative_step=False)

    if args.scheduler in ["constant"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
    elif args.scheduler in ["constant_with_warmup"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps * args.warmup)
    elif args.scheduler in ["tri_stage"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps * args.warmup, args.total_steps * args.decay, args.total_steps)
    else:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps * args.warmup, args.total_steps)

    if args.deepspeed:
        model_for_training, optimizer, _, scheduler = deepspeed.initialize(
            model=model_for_training,
            model_parameters=optimizer_grouped_parameters,
            args=args,
            optimizer=custom_optimizer,
            lr_scheduler=custom_scheduler,
            mpu=None,
            dist_init_required=False)
    else:
        if gpu_id is not None:
            model_for_training.cuda(gpu_id)
            if model_for_dataloader is not None:
                model_for_dataloader.cuda(gpu_id)
        optimizer = custom_optimizer
        scheduler = custom_scheduler
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model_for_training, optimizer = amp.initialize(model_for_training, optimizer, opt_level=args.fp16_opt_level)
            args.amp = amp

        if args.dist_train:
            # Initialize multiprocessing distributed training environment.
            dist.init_process_group(backend=args.backend,
                                    init_method=args.master_ip,
                                    world_size=args.world_size,
                                    rank=rank)
            model_for_training = DistributedDataParallel(model_for_training, device_ids=[gpu_id], find_unused_parameters=True)
            if model_for_dataloader is not None:
                model_for_dataloader = DistributedDataParallel(model_for_dataloader, device_ids=[gpu_id], find_unused_parameters=False)
            args.logger.info("Worker %d is training ..." % rank)
        else:
            args.logger.info("Worker is training ...")

    if args.dist_train:
        if model_for_dataloader is not None:
            # The dataloader uses the underlying module, not the DDP wrapper.
            model_for_dataloader = model_for_dataloader.module
        train_loader = str2dataloader[args.data_processor](args, args.dataset_path, args.batch_size, rank, args.world_size, gpu_id, True, model_for_dataloader)
    else:
        train_loader = str2dataloader[args.data_processor](args, args.dataset_path, args.batch_size, 0, 1, gpu_id, True, model_for_dataloader)

    trainer = str2trainer[args.data_processor](args)
    trainer.train(args, gpu_id, rank, train_loader, model_for_training, optimizer, scheduler)
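
# A minimal launch sketch (assuming a pretraining script, e.g. the project's
# pretrain.py, has already parsed the command-line options into `args`; the
# attribute names below all appear in this module, but parse_args() here is a
# hypothetical stand-in for that parsing step):
#
#     args = parse_args()                  # hypothetical argument parsing
#     args.total_steps = 100000
#     args.batch_size = 32
#     args.data_processor = "bert"         # selects BertTrainer via str2trainer
#     train_and_validate(args)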