"""Pretrain utilities.""" |
|
from datetime import datetime |
|
from functools import partial |
|
from collections import defaultdict |
|
|
|
import math |
|
import sys |
|
from contextlib import nullcontext |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import deepspeed |
|
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler |
|
import numpy as np |
|
|
|
from megatron.utils import ( |
|
Timers, |
|
init_wandb, |
|
get_ltor_masks_and_position_ids, |
|
reduce_losses, |
|
) |
|
|
|
from megatron import print_rank_0, mpu |
|
from megatron.model import ( |
|
GPT2ModelPipe, |
|
SoftEmbedding, |
|
get_params_for_weight_decay_optimization, |
|
mark_norms_for_sequence_parallel_grad_sync, |
|
) |
|
from megatron.mpu.mappings import gather_from_model_parallel_region |
|
from megatron.checkpointing import load_checkpoint, save_checkpoint |
|
from megatron.data.data_utils import ( |
|
build_train_valid_test_data_loaders, |
|
shift_and_wrap_data_loaders, |
|
) |
|
from megatron.initialize import initialize_megatron |
|
from megatron.learning_rates import AnnealingLR |
|
from megatron.logging import tb_wandb_log, training_log |
|
from megatron.utils import ( |
|
OverflowMonitor, |
|
get_noise_scale_logger, |
|
get_total_params, |
|
CharCounter, |
|
) |
|
from megatron.model.gpt2_model import cross_entropy |
|
from megatron.mpu import vocab_parallel_cross_entropy |
|
|
|
from pickle import dump |
|
import os |
|
|
|
|
|
def mup_weights_reinit(neox_args, model): |
|
def has_method(o, name): |
|
return callable(getattr(o, name, None)) |
|
|
|
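    # Walk every module: re-apply mup parameter rescaling where the layer flags it, and
    # re-run mup-aware weight init on layers that implement `mup_reinitialize_weights`.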
for layer in model.modules(): |
|
|
|
if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters: |
|
layer._rescale_parameters() |
|
|
|
if has_method(layer, "mup_reinitialize_weights"): |
|
layer.mup_reinitialize_weights(neox_args) |
|
|
|
|
|
def save_base_shapes(neox_args, base_shapes, use_cache): |
|
|
|
|
|
neox_args.use_mup = False |
|
|
|
base_model = GPT2ModelPipe( |
|
neox_args=neox_args, |
|
num_tokentypes=0, |
|
        parallel_output=(neox_args.train_impl != "rm"),
|
topology=mpu.get_topology(), |
|
use_cache=use_cache, |
|
) |
|
|
|
if not neox_args.is_pipe_parallel: |
|
base_model = base_model.to_sequential() |
|
|
|
try: |
|
import mup |
|
except ModuleNotFoundError: |
|
print("Please install mup https://github.com/microsoft/mup") |
|
raise Exception |
|
|
|
base_shapes = mup.get_shapes(base_model) |
|
|
|
del base_model |
|
|
|
old_hidden_size = neox_args.hidden_size |
|
neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale |
|
|
|
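    # Build a second "delta" model at a scaled hidden size; mup compares its shapes with the
    # base model's to infer which dimensions are width-like when writing the base-shapes file.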
delta_model = GPT2ModelPipe( |
|
neox_args=neox_args, |
|
num_tokentypes=0, |
|
        parallel_output=(neox_args.train_impl != "rm"),
|
topology=mpu.get_topology(), |
|
use_cache=use_cache, |
|
) |
|
|
|
if not neox_args.is_pipe_parallel: |
|
delta_model = delta_model.to_sequential() |
|
|
|
delta_shapes = mup.get_shapes(delta_model) |
|
|
|
|
|
neox_args.use_mup = True |
|
neox_args.hidden_size = old_hidden_size |
|
|
|
save_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}" |
|
print(f"saving base shapes at {save_shapes}") |
|
mup.make_base_shapes(base_shapes, delta_shapes, savefile=save_shapes) |
|
    print("base shapes saved... exiting")
|
sys.exit(1) |
|
|
|
|
|
def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator): |
|
from megatron.mup_substitute import get_coord_data |
|
from mup.coord_check import plot_coord_data |
|
|
|
def lazy_model(hidden_size): |
|
def gen(): |
|
old_hidden_size = neox_args.hidden_size |
|
neox_args.hidden_size = hidden_size |
|
|
|
model, optimizer, _, _ = setup_model_and_optimizer( |
|
neox_args=neox_args, use_cache=False |
|
) |
|
|
|
neox_args.hidden_size = old_hidden_size |
|
|
|
return model |
|
|
|
return gen |
|
|
|
models = {} |
|
|
|
|
|
for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)): |
|
models[hidden_size] = lazy_model(hidden_size) |
|
|
|
neox_args.use_mup = True |
|
df_up = get_coord_data( |
|
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True |
|
) |
|
neox_args.use_mup = False |
|
df_sp = get_coord_data( |
|
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False |
|
) |
|
|
|
plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg") |
|
plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg") |
|
|
|
print_rank_0("Saved coord check plots... exiting") |
|
sys.exit(1) |
|
|
|
|
|
def update_iterations(neox_args, data_loaders):
    """
    Compute the number of train iterations from train_epochs when train_iters is not
    specified in the config, and update the neox_args object in place.

    Note that if len(train_dataloader) % gradient_accumulation_steps != 0, this will configure neox
    to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
    times.
    """
|
if (not neox_args.do_train) or (neox_args.train_iters is not None): |
|
pass |
|
elif neox_args.train_iters is None and neox_args.train_epochs is None: |
|
        print_rank_0(
            "ERROR: Failed to specify either train_epochs or train_iters in config file"
        )
|
else: |
|
global_rank = torch.distributed.get_rank() |
|
|
|
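        # Rank 0 derives the iteration count from the dataloader length and the epoch/grad-accum
        # settings, then broadcasts it so every rank trains for the same number of iterations.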
if global_rank == 0: |
|
train_dataloader = data_loaders["train"] |
|
train_epochs = neox_args.train_epochs |
|
gradient_accumulation_steps = neox_args.gradient_accumulation_steps |
|
|
|
train_dataloader_len = len(train_dataloader) |
|
train_iterations = ( |
|
train_dataloader_len * train_epochs |
|
) // gradient_accumulation_steps |
|
|
|
train_iters_tensor = torch.cuda.LongTensor([train_iterations]) |
|
else: |
|
train_iters_tensor = torch.cuda.LongTensor([0]) |
|
|
|
torch.distributed.broadcast(train_iters_tensor, src=0) |
|
|
|
neox_args.train_iters = train_iters_tensor[0].item() |
|
|
|
print_rank_0( |
|
f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs." |
|
) |
|
|
|
|
|
def pretrain(neox_args): |
|
"""Main training program. |
|
|
|
This function will run the following in the order provided: |
|
1) initialize Megatron. |
|
2) get train/val/test datasets. |
|
3) setup model, optimizer and lr schedule. |
|
    4) configure data loading.
|
5) train the model. |
|
|
|
Arguments: |
|
neox_args: an instance of NeoXArgs containing the configuration for pretrain |
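
    Typical usage (a sketch; assumes the NeoXArgs helpers used by this repo's train.py
    entry point rather than anything defined in this module):

        neox_args = NeoXArgs.consume_neox_args()
        neox_args.configure_distributed_args()
        neox_args.build_tokenizer()
        pretrain(neox_args)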
|
|
|
""" |
|
|
|
init_wandb(neox_args=neox_args) |
|
timers = Timers( |
|
use_wandb=neox_args.use_wandb, |
|
tensorboard_writer=neox_args.tensorboard_writer, |
|
comet_experiment=neox_args.comet_experiment, |
|
) |
|
|
|
|
|
initialize_megatron(neox_args=neox_args) |
|
|
|
|
|
timers("train/valid/test data loaders").start() |
|
data_loaders = build_train_valid_test_data_loaders(neox_args=neox_args) |
|
update_iterations(neox_args=neox_args, data_loaders=data_loaders) |
|
timers("train/valid/test data loaders").stop() |
|
|
|
|
|
timers("model and optimizer").start() |
|
model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer( |
|
neox_args=neox_args, use_cache=False, iteration=neox_args.iteration |
|
) |
|
timers("model and optimizer").stop() |
|
|
|
|
|
timers("train/valid/test data iterators").start() |
|
( |
|
train_data_iterator, |
|
valid_data_iterator, |
|
test_data_iterator, |
|
) = shift_and_wrap_data_loaders(neox_args=neox_args, data_loaders=data_loaders) |
|
timers("train/valid/test data iterators").stop() |
|
|
|
if neox_args.use_mup and neox_args.coord_check: |
|
mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator) |
|
|
|
|
|
print_rank_0("done with setups ...") |
|
timers.log( |
|
[ |
|
"train/valid/test data loaders", |
|
"model and optimizer", |
|
"train/valid/test data iterators", |
|
] |
|
) |
|
print_rank_0("training ...") |
|
|
|
iteration = neox_args.iteration |
|
|
|
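    # Optionally write a step-0 checkpoint before any training when 0 appears in extra_save_iters.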
if ( |
|
neox_args.save |
|
and neox_args.extra_save_iters |
|
and 0 in neox_args.extra_save_iters |
|
and iteration == 0 |
|
): |
|
save_checkpoint( |
|
neox_args=neox_args, |
|
iteration=iteration, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
) |
|
|
|
if neox_args.do_train and neox_args.train_iters > 0: |
|
iteration = train( |
|
neox_args=neox_args, |
|
timers=timers, |
|
model=model, |
|
reference_model=reference_model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
train_data_iterator=train_data_iterator, |
|
valid_data_iterator=valid_data_iterator, |
|
) |
|
|
|
if neox_args.do_valid: |
|
prefix = "the end of training for val data" |
|
evaluate_and_print_results( |
|
neox_args=neox_args, |
|
prefix=prefix, |
|
forward_step_func=forward_step, |
|
data_iterator=valid_data_iterator, |
|
model=model, |
|
iteration=iteration, |
|
verbose=False, |
|
timers=timers, |
|
reference_model=reference_model, |
|
) |
|
|
|
if neox_args.save and iteration != 0: |
|
save_checkpoint( |
|
neox_args=neox_args, |
|
iteration=iteration, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
) |
|
|
|
if neox_args.do_test: |
|
|
|
prefix = "the end of training for test data" |
|
evaluate_and_print_results( |
|
neox_args=neox_args, |
|
prefix=prefix, |
|
forward_step_func=forward_step, |
|
data_iterator=test_data_iterator, |
|
model=model, |
|
iteration=iteration, |
|
verbose=True, |
|
timers=timers, |
|
chart_name="test", |
|
reference_model=reference_model, |
|
) |
|
|
|
|
|
def _get_batch(neox_args, tokenizer, keys, data, datatype, label_mask_zero=False):
    """Support function for get_batch / get_batch_pipe (to avoid code repetition)"""
|
data_b = mpu.broadcast_data(keys, data, datatype) |
|
token_key = keys[0] |
|
label_key = keys[1] if len(keys) > 1 else None |
|
|
|
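    # Shift by one token for next-token prediction. Labels come either from a separate "label"
    # stream (negative ids mark positions to ignore) or from the input tokens themselves.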
tokens_ = data_b[token_key].long() |
|
if label_key in data_b: |
|
label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() |
|
labels = torch.where( |
|
data_b[label_key].long() >= 0, |
|
data_b[label_key].long(), |
|
torch.zeros_like(data_b[label_key].long()), |
|
)[:, 1:].contiguous() |
|
else: |
|
label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() |
|
labels = tokens_[:, 1:].contiguous() |
|
if label_mask_zero: |
|
labels = labels * label_mask |
|
tokens = tokens_[:, :-1].contiguous() |
|
|
|
|
|
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( |
|
data=tokens, |
|
eod_token=neox_args.tokenizer.eod, |
|
eod_mask_loss=neox_args.eod_mask_loss, |
|
sliding_window_width=neox_args.sliding_window_width, |
|
) |
|
|
|
|
|
loss_mask = label_mask.to(loss_mask.dtype) * loss_mask |
|
return tokens, labels, loss_mask, attention_mask, position_ids |
|
|
|
|
|
def get_batch(neox_args, data_iterator): |
|
"""Generate a batch""" |
|
|
|
|
|
if neox_args.train_impl in ["normal", "kto"]: |
|
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] |
|
elif neox_args.train_impl in ["dpo", "rm"]: |
|
keys = ( |
|
[["pos", "pos_label"], ["neg", "neg_label"]] |
|
if neox_args.pos_train_label_data_paths |
|
else [["pos"], ["neg"]] |
|
) |
|
datatype = torch.int64 |
|
|
|
|
|
if data_iterator is not None: |
|
data = next(data_iterator) |
|
else: |
|
data = None |
|
if neox_args.train_impl == "normal": |
|
return _get_batch( |
|
neox_args=neox_args, |
|
tokenizer=neox_args.tokenizer, |
|
keys=keys, |
|
data=data, |
|
datatype=datatype, |
|
) |
|
elif neox_args.train_impl == "kto": |
|
assert ( |
|
neox_args.train_micro_batch_size_per_gpu > 1 |
|
), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1." |
|
tup = _get_batch( |
|
neox_args=neox_args, |
|
tokenizer=neox_args.tokenizer, |
|
keys=keys, |
|
data=data, |
|
datatype=datatype, |
|
) |
|
|
|
|
|
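        # KTO also needs the per-token reward labels (desirable vs. undesirable) and, when a
        # precompute model was used, the cached reference log-probabilities.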
rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][ |
|
:, :-1 |
|
].contiguous() |
|
ref_data = ( |
|
mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous() |
|
if neox_args.precompute_model_name |
|
else None |
|
) |
|
return tup + (rw_data, ref_data) |
|
elif neox_args.train_impl in ["dpo", "rm"]: |
|
pos_tup = _get_batch( |
|
neox_args=neox_args, |
|
tokenizer=neox_args.tokenizer, |
|
keys=keys[0], |
|
data=data, |
|
datatype=datatype, |
|
label_mask_zero=True, |
|
) |
|
neg_tup = _get_batch( |
|
neox_args=neox_args, |
|
tokenizer=neox_args.tokenizer, |
|
keys=keys[1], |
|
data=data, |
|
datatype=datatype, |
|
label_mask_zero=True, |
|
) |
|
if neox_args.precompute_model_name: |
|
ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) |
|
else: |
|
ref_data = {"pos_ref": None} |
|
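        # Concatenate chosen and rejected examples along the batch dimension so one forward pass
        # scores both; forward_step later splits them back apart with torch.chunk(..., 2, 0).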
return [ |
|
torch.cat((pos_item, neg_item), dim=0) |
|
for pos_item, neg_item in zip(pos_tup, neg_tup) |
|
] + [ |
|
torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[ |
|
:, :-1 |
|
].contiguous() |
|
if ref_data["pos_ref"] is not None |
|
else None |
|
] |
|
|
|
|
|
def get_batch_pipe(data, neox_args, curr_scheduler=None): |
|
"""A modification of get_batch() to work with the latest batch instead of an iterator.""" |
|
|
|
assert neox_args.train_impl not in [ |
|
"kto", |
|
"dpo", |
|
"rm", |
|
], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0" |
|
|
|
|
|
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] |
|
datatype = torch.int64 |
|
|
|
tokens, labels, loss_mask, attention_mask, position_ids = _get_batch( |
|
neox_args, neox_args.tokenizer, keys, data, datatype |
|
) |
|
if curr_scheduler is not None: |
|
|
|
curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1) |
|
if curriculum_seqlen < tokens.size()[1]: |
|
|
|
|
|
|
|
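            # Truncate tokens/labels/masks to the current curriculum length; the attention mask
            # must be cropped along both its query and key dimensions.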
tokens = tokens[:, :curriculum_seqlen].contiguous() |
|
position_ids = position_ids[:, :curriculum_seqlen].contiguous() |
|
if labels is not None: |
|
labels = labels[:, :curriculum_seqlen].contiguous() |
|
if loss_mask is not None: |
|
loss_mask = loss_mask[:, :curriculum_seqlen].contiguous() |
|
|
|
attention_mask = attention_mask[ |
|
:, :, :curriculum_seqlen, :curriculum_seqlen |
|
].contiguous() |
|
|
|
|
|
return (tokens, position_ids, attention_mask), (labels, loss_mask) |
|
|
|
|
|
def get_batch_sequential(forward_input, neox_args): |
|
"""A modification of get_batch() to work with the latest batch instead of an iterator.""" |
|
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( |
|
data=forward_input[0], |
|
eod_token=neox_args.tokenizer.eod, |
|
eod_mask_loss=neox_args.eod_mask_loss, |
|
) |
|
return (forward_input[0], forward_input[1], attention_mask) |
|
|
|
|
|
def average_losses_across_data_parallel_group(losses): |
|
"""Reduce a tensor of losses across all GPUs.""" |
|
averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) |
|
torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group()) |
|
averaged_losses = averaged_losses / torch.distributed.get_world_size( |
|
group=mpu.get_data_parallel_group() |
|
) |
|
|
|
return averaged_losses |
|
|
|
|
|
def mb_moe_loss_func(args, loss_mask, output_tensor=None): |
|
from megatron.model import megablocks_utils |
|
from megatron.model.megablocks_utils import moe |
|
|
|
|
|
|
|
|
|
|
|
loss, loss_dict = (None, {}) |
|
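    # The branch below is intentionally disabled (`loss_func` is not defined in this module);
    # this helper only contributes the MoE load-balancing loss.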
if False: |
|
assert output_tensor is not None |
|
loss, loss_dict = loss_func(loss_mask, output_tensor) |
|
assert loss.numel() == 1 |
|
|
|
|
|
|
|
|
|
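    # With activation checkpointing the MoE layers run twice per micro-batch, so load-balancing
    # losses get recorded twice; keep a single copy before computing the batched loss.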
if args.checkpoint_activations: |
|
|
|
|
|
load_balancing_loss_data = moe.get_load_balancing_loss() |
|
if args.num_layers * 2 == len(load_balancing_loss_data): |
|
load_balancing_loss_data = load_balancing_loss_data[args.num_layers :] |
|
moe.clear_load_balancing_loss() |
|
for x in load_balancing_loss_data: |
|
moe.save_load_balancing_loss(x) |
|
|
|
|
|
    megablocks_args = megablocks_utils.as_megablocks_args(args)
|
lbl = moe.batched_load_balancing_loss(megablocks_args) |
|
moe.clear_load_balancing_loss() |
|
|
|
|
|
|
|
averaged_lbl = average_losses_across_data_parallel_group([lbl]) |
|
loss_dict["load balancing loss"] = averaged_lbl[0] |
|
return averaged_lbl, loss_dict |
|
|
|
|
|
def get_logp(logits, labels, force_fp32=False):
    """Per-token log-probabilities of `labels` under `logits` (negated vocab-parallel cross entropy)."""
    if force_fp32:
        logits = logits.float()
    return -vocab_parallel_cross_entropy(logits, labels)
|
|
|
|
|
def get_pos_neg_logp(logits, labels, force_fp32=False):
    """Per-token log-probabilities, split into (chosen, rejected) halves along the batch dimension."""
    if force_fp32:
        logits = logits.float()
    return torch.chunk(-vocab_parallel_cross_entropy(logits, labels), 2, 0)
|
|
|
|
|
def forward_step( |
|
data_iterator, |
|
model, |
|
neox_args, |
|
timers, |
|
return_logits=False, |
|
is_train=False, |
|
reference_model=None, |
|
): |
|
"""Forward step.""" |
|
if neox_args.is_pipe_parallel: |
|
return model.eval_batch(data_iterator, return_logits=return_logits) |
|
|
|
|
|
if neox_args.memory_profiling and neox_args.iteration: |
|
torch.cuda.nvtx.range_push(f"Get batch") |
|
if timers is not None: |
|
timers("batch generator").start() |
|
if neox_args.train_impl == "normal": |
|
tokens, labels, loss_mask, attention_mask, position_ids = get_batch( |
|
neox_args=neox_args, data_iterator=data_iterator |
|
) |
|
elif neox_args.train_impl == "kto": |
|
( |
|
tokens, |
|
labels, |
|
loss_mask, |
|
attention_mask, |
|
position_ids, |
|
rewards, |
|
ref_logp, |
|
) = get_batch(neox_args=neox_args, data_iterator=data_iterator) |
|
    elif neox_args.train_impl in ["dpo", "rm"]:
|
tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( |
|
neox_args=neox_args, data_iterator=data_iterator |
|
) |
|
|
|
if timers is not None: |
|
timers("batch generator").stop() |
|
if neox_args.memory_profiling: |
|
torch.cuda.nvtx.range_pop() |
|
|
|
if neox_args.memory_profiling: |
|
torch.cuda.nvtx.range_push(f"Forward pass") |
|
metrics = {} |
|
if neox_args.train_impl == "normal": |
|
|
|
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) |
|
if type(maybe_tuple) is tuple: |
|
outputs, moe_losses = maybe_tuple |
|
else: |
|
outputs = maybe_tuple |
|
moe_losses = [] |
|
if ( |
|
is_train |
|
and neox_args.curriculum_learning |
|
and neox_args.curriculum_seqlen < neox_args.seq_length |
|
): |
|
loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() |
|
labels = labels[:, : neox_args.curriculum_seqlen].contiguous() |
|
main_loss = cross_entropy( |
|
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy |
|
) |
|
if neox_args.moe_num_experts > 1: |
|
if neox_args.moe_type == "deepspeed": |
|
moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) |
|
elif neox_args.moe_type == "megablocks": |
|
moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] |
|
else: |
|
raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") |
|
else: |
|
moe_loss = 0.0 |
|
loss = main_loss + moe_loss |
|
elif neox_args.train_impl == "rm": |
|
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) |
|
if type(maybe_tuple) is tuple: |
|
outputs, _ = maybe_tuple |
|
else: |
|
outputs = maybe_tuple |
|
pos, neg = torch.chunk(outputs, 2, 0) |
|
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) |
|
|
|
|
|
|
|
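        # The reward head scores every token; the loss mask is expected to mark a single scoring
        # position per sequence, so gather one scalar per chosen/rejected example at that index.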
pos_indx = pos_loss_mask.nonzero() |
|
neg_indx = neg_loss_mask.nonzero() |
|
|
|
pos_indx = pos_indx[:, 1].unsqueeze(1) |
|
neg_indx = neg_indx[:, 1].unsqueeze(1) |
|
pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx) |
|
neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx) |
|
with torch.no_grad(): |
|
metrics["pos_values"] = pos.clone().detach().mean() |
|
metrics["neg_values"] = neg.clone().detach().mean() |
|
metrics["margin"] = (pos - neg).clone().detach().mean() |
|
metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean() |
|
loss = (-F.logsigmoid(pos - neg).mean()) + ( |
|
(neox_args.z_loss * (pos**2 + neg**2)).mean() |
|
) |
|
elif neox_args.train_impl == "dpo": |
|
|
|
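        # Reference-policy log-probabilities: skipped entirely for reference-free DPO, taken from
        # precomputed values when available, otherwise computed here with the frozen reference model.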
with torch.inference_mode(): |
|
|
|
token_logp_labels = labels.clone() |
|
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) |
|
if neox_args.dpo_reference_free: |
|
ref_pos = 0 |
|
ref_neg = 0 |
|
elif ref_logp is None: |
|
ref_maybe_tuple = reference_model( |
|
(tokens, position_ids, attention_mask), neox_args=neox_args |
|
) |
|
if type(ref_maybe_tuple) is tuple: |
|
|
|
ref_outputs, _ = ref_maybe_tuple |
|
else: |
|
ref_outputs = ref_maybe_tuple |
|
ref_pos, ref_neg = get_pos_neg_logp( |
|
ref_outputs, token_logp_labels, neox_args.dpo_fp32 |
|
) |
|
else: |
|
ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0) |
|
ref_pos = (ref_pos * pos_loss_mask).sum(-1) |
|
ref_neg = (ref_neg * neg_loss_mask).sum(-1) |
|
chosen_maybe_tuple = model( |
|
(tokens, position_ids, attention_mask), neox_args=neox_args |
|
) |
|
if type(chosen_maybe_tuple) is tuple: |
|
|
|
chosen_outputs, _ = chosen_maybe_tuple |
|
else: |
|
chosen_outputs = chosen_maybe_tuple |
|
chosen_pos, chosen_neg = get_pos_neg_logp( |
|
chosen_outputs, token_logp_labels, neox_args.dpo_fp32 |
|
) |
|
chosen_pos = (chosen_pos * pos_loss_mask).sum(-1) |
|
chosen_neg = (chosen_neg * neg_loss_mask).sum(-1) |
|
with torch.no_grad(): |
|
|
|
if not neox_args.dpo_reference_free: |
|
metrics["ref_neg"] = ref_neg.clone().detach().mean() |
|
metrics["ref_pos"] = ref_pos.clone().detach().mean() |
|
metrics["chosen_neg"] = chosen_neg.clone().detach().mean() |
|
metrics["chosen_pos"] = chosen_pos.clone().detach().mean() |
|
if not neox_args.dpo_reference_free: |
|
chosen_rewards = neox_args.dpo_beta * ( |
|
chosen_pos.clone().detach() - ref_pos.clone().detach() |
|
) |
|
rejected_rewards = neox_args.dpo_beta * ( |
|
chosen_neg.clone().detach() - ref_neg.clone().detach() |
|
) |
|
metrics["chosen_rewards"] = chosen_rewards.mean() |
|
metrics["rejected_rewards"] = rejected_rewards.mean() |
|
reward_acc = (chosen_rewards > rejected_rewards).float() |
|
metrics["reward_acc"] = reward_acc.mean() |
|
metrics["margins"] = (chosen_rewards - rejected_rewards).mean() |
|
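        # Standard DPO objective:
        #   loss = -log sigmoid(beta * [(log pi(y_w) - log pi(y_l)) - (log ref(y_w) - log ref(y_l))])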
        pi_logratios = chosen_pos - chosen_neg
        ref_logratios = ref_pos - ref_neg
        logits = pi_logratios - ref_logratios
|
loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() |
|
elif neox_args.train_impl == "kto": |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
token_logp_labels = labels.clone() |
|
token_logp_labels[token_logp_labels == -100] = 0 |
|
if ref_logp is None: |
|
|
|
ref_maybe_tuple = reference_model( |
|
(tokens, position_ids, attention_mask), neox_args=neox_args |
|
) |
|
if type(ref_maybe_tuple) is tuple: |
|
|
|
ref_outputs, _ = ref_maybe_tuple |
|
else: |
|
ref_outputs = ref_maybe_tuple |
|
|
|
ref_outputs = gather_from_model_parallel_region(ref_outputs) |
|
|
|
ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32) |
|
else: |
|
print(f"REF LOGP: {ref_logp.clone().detach().mean()}") |
|
ref_logp = ref_logp * loss_mask |
|
scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight |
|
scaling += ( |
|
rewards.sum(-1) < -0.001 |
|
).float() * neox_args.kto_undesirable_weight |
|
pos_mask = (rewards > 0.001).float() |
|
neg_mask = (rewards < -0.001).float() |
|
chosen_maybe_tuple = model( |
|
(tokens, position_ids, attention_mask), neox_args=neox_args |
|
) |
|
if type(chosen_maybe_tuple) is tuple: |
|
|
|
chosen_outputs, _ = chosen_maybe_tuple |
|
else: |
|
chosen_outputs = chosen_maybe_tuple |
|
chosen_outputs = gather_from_model_parallel_region(chosen_outputs) |
|
chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32) |
|
chosen_logp = chosen_logp * loss_mask |
|
with torch.no_grad(): |
|
|
|
metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean() |
|
metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean() |
|
metrics["pos_ref_logp"] = ( |
|
(ref_logp * pos_mask).clone().detach().sum(-1).mean() |
|
) |
|
metrics["neg_ref_logp"] = ( |
|
(ref_logp * neg_mask).clone().detach().sum(-1).mean() |
|
) |
|
metrics["pos_policy_logp"] = ( |
|
(chosen_logp * pos_mask).clone().detach().sum(-1).mean() |
|
) |
|
metrics["neg_policy_logp"] = ( |
|
(chosen_logp * neg_mask).clone().detach().sum(-1).mean() |
|
) |
|
metrics["kl"] = ( |
|
chosen_logp.clone().detach() - ref_logp.clone().detach() |
|
).sum() / loss_mask.sum() |
|
policy_rewards = ( |
|
neox_args.kto_beta |
|
* rewards |
|
* (chosen_logp.clone().detach() - ref_logp.clone().detach()) |
|
) |
|
reward_acc = (policy_rewards.sum(-1) > 0.0).float() |
|
metrics["reward_acc"] = reward_acc.mean() |
|
metrics["policy_rewards"] = policy_rewards.sum() |
|
print(metrics) |
|
pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0) |
|
ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0) |
|
reward1, reward2 = torch.chunk(rewards, 2, 0) |
|
scaling1, scaling2 = torch.chunk(scaling, 2, 0) |
|
kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean() |
|
kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean() |
|
log_ratio1 = pol_logp1 - ref_logp1 |
|
log_ratio2 = pol_logp2 - ref_logp2 |
|
|
|
|
|
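        # KTO loss: each half of the micro-batch is weighted by its desirable/undesirable scaling and
        # penalized via 1 - sigmoid(beta * reward * (log-ratio - detached KL estimate of the *other*
        # half)), which is why KTO requires train_micro_batch_size_per_gpu > 1.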
loss = ( |
|
0.5 |
|
* scaling1.mean(-1) |
|
* ( |
|
1 |
|
- F.sigmoid( |
|
( |
|
neox_args.kto_beta |
|
* reward1.mean(-1) |
|
* (log_ratio1.sum(-1) - kl2.clone().detach()) |
|
) |
|
) |
|
) |
|
) + ( |
|
0.5 |
|
* scaling2.mean(-1) |
|
* ( |
|
1 |
|
- F.sigmoid( |
|
( |
|
neox_args.kto_beta |
|
* reward2.mean(-1) |
|
* (log_ratio2.sum(-1) - kl1.clone().detach()) |
|
) |
|
) |
|
) |
|
) |
|
|
|
loss = loss.mean() |
|
|
|
if neox_args.memory_profiling: |
|
torch.cuda.nvtx.range_pop() |
|
if return_logits: |
|
return loss, outputs, metrics |
|
return loss, metrics |
|
|
|
|
|
def get_model(neox_args, use_cache=False): |
|
"""Build the model.""" |
|
|
|
|
|
print_rank_0("building GPT2 model ...") |
|
|
|
|
|
|
|
old_use_mup = neox_args.use_mup |
|
neox_args.use_mup = False |
|
|
|
if neox_args.zero_stage in [2, 3]: |
|
if neox_args.pipe_parallel_size == 1: |
|
print_rank_0( |
|
"ZeRO stage 2/3 and the PipelineModule are incompatible, please set 'pipe_parallel_size' to 0 instead" |
|
) |
|
exit() |
|
if neox_args.pipe_parallel_size > 1: |
|
            print_rank_0(
                "ZeRO stage 2/3 and pipeline parallelism are not supported simultaneously"
            )
|
exit() |
|
if neox_args.model_parallel_size > 1: |
|
            print_rank_0(
                "ZeRO stage 2/3 and model parallelism are not currently supported simultaneously"
            )
|
exit() |
|
|
|
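    # Under ZeRO stage 3, construct the model inside deepspeed.zero.Init so parameters are
    # partitioned across ranks at creation time; otherwise use a no-op context manager.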
with deepspeed.zero.Init( |
|
config_dict_or_path=neox_args.deepspeed_config |
|
) if neox_args.zero_stage == 3 else nullcontext() as gs: |
|
model = GPT2ModelPipe( |
|
neox_args=neox_args, |
|
num_tokentypes=0, |
|
            parallel_output=(neox_args.train_impl != "rm"),
|
topology=mpu.get_topology(), |
|
use_cache=use_cache, |
|
) |
|
|
|
|
|
if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get( |
|
"enabled", False |
|
): |
|
soft_prompt = SoftEmbedding( |
|
neox_args, |
|
wte=getattr(model, "0").word_embeddings, |
|
n_tokens=neox_args.soft_prompt_tuning.get("n_tokens", 10), |
|
init_string=neox_args.soft_prompt_tuning.get("init_string", ""), |
|
init_range=neox_args.soft_prompt_tuning.get("init_range", 0.5), |
|
) |
|
model.insert_layers( |
|
layers=soft_prompt, idx=1 |
|
) |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
            if "soft_embedding" not in name:
|
param.requires_grad = False |
|
|
|
if not neox_args.is_pipe_parallel: |
|
|
|
model = model.to_sequential() |
|
|
|
neox_args.use_mup = old_use_mup |
|
|
|
if neox_args.use_mup: |
|
try: |
|
import mup |
|
except ModuleNotFoundError: |
|
print("Please install mup https://github.com/microsoft/mup") |
|
raise Exception |
|
|
|
base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}" |
|
|
|
if neox_args.save_base_shapes: |
|
save_base_shapes(neox_args, base_shapes, use_cache) |
|
|
|
mup.set_base_shapes(model, base_shapes) |
|
|
|
|
|
mup_weights_reinit(neox_args, model) |
|
|
|
if neox_args.deepspeed: |
|
|
|
return model |
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
|
|
|
|
def get_optimizer(model, neox_args, dummy=False): |
|
"""Set up the optimizer.""" |
|
if neox_args.no_load_optim and neox_args.deepspeed: |
|
|
|
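        # DeepSpeed still expects an optimizer object even though its state won't be loaded or used,
        # so build a minimal zero-lr "dummy" Adam over a single parameter (see the dummy handling below).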
dummy = True |
|
neox_args.optimizer = {"params": {"lr": 0.0}} |
|
neox_args.optimizer_type = "adam" |
|
elif neox_args.no_load_optim: |
|
return None, None |
|
|
|
if neox_args.optimizer is None: |
|
print_rank_0( |
|
f"ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)" |
|
) |
|
exit() |
|
|
|
param_groups = get_params_for_weight_decay_optimization(model, neox_args) |
|
print_rank_0( |
|
f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}' |
|
) |
|
|
|
if neox_args.create_moe_param_group: |
|
from deepspeed.moe.utils import ( |
|
is_moe_param, |
|
split_params_into_different_moe_groups_for_optimizer, |
|
) |
|
|
|
param_groups = split_params_into_different_moe_groups_for_optimizer( |
|
param_groups |
|
) |
|
|
|
|
|
for param_group in param_groups: |
|
for param in param_group["params"]: |
|
if not hasattr(param, "model_parallel"): |
|
param.model_parallel = False |
|
|
|
|
|
_param_groups = [] |
|
for param_group in param_groups: |
|
trainable_params = [p for p in param_group["params"] if p.requires_grad] |
|
if dummy: |
|
trainable_params = [trainable_params[0]] |
|
param_group["params"] = trainable_params |
|
_param_groups.append(param_group) |
|
if dummy: |
|
|
|
break |
|
param_groups = _param_groups |
|
|
|
|
|
assert not neox_args.use_mup or ( |
|
neox_args.optimizer_type.lower() == "adam" |
|
or neox_args.optimizer_type.lower() == "sgd" |
|
), f"If use_mup == True, you must specify either the adam or sgd optimizers. You passed: {neox_args.optimizer_type.lower()}" |
|
|
|
if neox_args.optimizer_type.lower() in ["cpu_adam", "cpu_torch_adam"]: |
|
        if neox_args.optimizer_type.lower() == "cpu_torch_adam":
|
cpu_adam_optimizer = torch.optim.Adam |
|
else: |
|
from deepspeed.ops.adam import DeepSpeedCPUAdam |
|
|
|
cpu_adam_optimizer = DeepSpeedCPUAdam |
|
optimizer = cpu_adam_optimizer( |
|
param_groups, |
|
weight_decay=neox_args.weight_decay, |
|
**neox_args.optimizer["params"], |
|
) |
|
elif neox_args.optimizer_type.lower() == "onebitadam": |
|
assert neox_args.deepspeed |
|
optimizer = None |
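        # OneBitAdam has to be instantiated by the DeepSpeed engine itself from the deepspeed config,
        # so no optimizer object is constructed here (param_groups are handed to deepspeed.initialize).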
|
|
|
elif neox_args.optimizer_type.lower() == "sm3": |
|
from .optimizers import SM3 |
|
|
|
optimizer = SM3(param_groups, **neox_args.optimizer["params"]) |
|
elif neox_args.optimizer_type.lower() == "madgrad_wd": |
|
from .optimizers import madgrad_wd |
|
|
|
optimizer = madgrad_wd( |
|
param_groups, |
|
weight_decay=neox_args.weight_decay, |
|
**neox_args.optimizer["params"], |
|
) |
|
elif neox_args.optimizer_type.lower() == "lion": |
|
|
|
if neox_args.zero_optimization["stage"] != 0: |
|
from deepspeed.ops.lion import FusedLion |
|
|
|
lion_optimizer = FusedLion |
|
|
|
else: |
|
from .optimizers import Lion |
|
|
|
lion_optimizer = Lion |
|
|
|
optimizer = lion_optimizer( |
|
param_groups, |
|
weight_decay=neox_args.weight_decay, |
|
**neox_args.optimizer["params"], |
|
) |
|
elif neox_args.optimizer_type.lower() == "adam": |
|
|
|
if neox_args.use_mup: |
|
try: |
|
from mup import MuAdam |
|
|
|
adam_optimizer = MuAdam |
|
except ModuleNotFoundError: |
|
print("Please install mup https://github.com/microsoft/mup") |
|
raise Exception |
|
else: |
|
if neox_args.use_bnb_optimizer: |
|
try: |
|
import bitsandbytes as bnb |
|
|
|
adam_optimizer = bnb.optim.Adam8bit |
|
except ModuleNotFoundError: |
|
print( |
|
"Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes." |
|
) |
|
raise Exception |
|
else: |
|
try: |
|
|
|
from apex.optimizers import FusedAdam as Adam |
|
except ImportError: |
|
|
|
print( |
|
"WARNING: APEX not installed - defaulting to deepspeed's fused adam" |
|
) |
|
from deepspeed.ops.adam import FusedAdam as Adam |
|
adam_optimizer = Adam |
|
optimizer = adam_optimizer( |
|
param_groups, |
|
weight_decay=neox_args.weight_decay, |
|
**neox_args.optimizer["params"], |
|
) |
|
elif neox_args.optimizer_type.lower() == "sgd": |
|
try: |
|
from mup import MuSGD |
|
except ModuleNotFoundError: |
|
print("Please install mup https://github.com/microsoft/mup") |
|
raise Exception |
|
optimizer = MuSGD( |
|
param_groups, |
|
weight_decay=neox_args.weight_decay, |
|
**neox_args.optimizer["params"], |
|
) |
|
else: |
|
raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized") |
|
|
|
if neox_args.deepspeed: |
|
|
|
return optimizer, param_groups |
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
|
|
|
|
def get_learning_rate_scheduler(optimizer, neox_args): |
|
"""Build the learning rate scheduler.""" |
|
if (neox_args.no_load_optim) and not neox_args.deepspeed: |
|
|
|
return None |
|
if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam": |
|
print_rank_0( |
|
"WARNING: onebitadam requires the lr scheduler be built by deepspeed - " |
|
"Make sure one is added to your deepspeed config" |
|
) |
|
return None |
|
|
|
|
|
if neox_args.lr_decay_iters is not None: |
|
num_iters = neox_args.lr_decay_iters |
|
elif neox_args.lr_decay_fraction is not None: |
|
num_iters = math.floor(neox_args.train_iters * neox_args.lr_decay_fraction) |
|
else: |
|
num_iters = neox_args.train_iters |
|
num_iters = max(1, num_iters) |
|
init_step = 0 |
|
warmup_iter = neox_args.warmup * num_iters |
|
lr_scheduler = AnnealingLR( |
|
optimizer, |
|
start_lr=neox_args.lr, |
|
warmup_iter=warmup_iter, |
|
total_iters=num_iters, |
|
decay_style=neox_args.lr_decay_style, |
|
last_iter=init_step, |
|
min_lr=neox_args.min_lr, |
|
use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler, |
|
override_lr_scheduler=neox_args.override_lr_scheduler, |
|
use_mup=neox_args.use_mup, |
|
) |
|
|
|
return lr_scheduler |
|
|
|
|
|
def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
    """Setup model, optimizer, and lr scheduler (plus a frozen reference model for DPO/KTO when required)."""
    # Optionally record CUDA memory allocation history so save_snapshot() can dump it later.
|
if neox_args.memory_profiling: |
|
torch.cuda.memory._record_memory_history( |
|
True, |
|
|
|
trace_alloc_max_entries=100000, |
|
trace_alloc_record_context=True, |
|
        )

    # Set up the model and optimizer.
|
needs_reference_model = ( |
|
(neox_args.train_impl == "dpo") |
|
and (neox_args.precompute_model_name is None) |
|
and (not neox_args.dpo_reference_free) |
|
) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) |
|
model = get_model(neox_args=neox_args, use_cache=use_cache) |
|
if needs_reference_model: |
|
reference_model = get_model(neox_args=neox_args, use_cache=use_cache) |
|
else: |
|
reference_model = None |
|
optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args) |
|
lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args) |
|
if neox_args.deepspeed and needs_reference_model: |
|
|
|
ref_optimizer, ref_param_groups = get_optimizer( |
|
model=reference_model, neox_args=neox_args, dummy=True |
|
) |
|
ref_lr_scheduler = get_learning_rate_scheduler( |
|
optimizer=ref_optimizer, neox_args=neox_args |
|
) |
|
else: |
|
ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None |
|
if neox_args.deepspeed: |
|
print_rank_0("DeepSpeed is enabled.") |
|
_model_params = param_groups if optimizer is None else None |
|
_lr_scheduler = lr_scheduler |
|
|
|
model, optimizer, _, lr_scheduler = deepspeed.initialize( |
|
model=model, |
|
optimizer=optimizer, |
|
args=neox_args, |
|
lr_scheduler=_lr_scheduler, |
|
dist_init_required=False, |
|
model_parameters=_model_params, |
|
|
|
|
|
mpu=mpu if not neox_args.is_pipe_parallel else None, |
|
) |
|
if needs_reference_model: |
|
reference_model, _, _, _ = deepspeed.initialize( |
|
model=reference_model, |
|
optimizer=ref_optimizer, |
|
args=neox_args, |
|
lr_scheduler=ref_lr_scheduler, |
|
dist_init_required=False, |
|
model_parameters=ref_param_groups, |
|
mpu=mpu if not neox_args.is_pipe_parallel else None, |
|
) |
|
mark_norms_for_sequence_parallel_grad_sync(model, neox_args) |
|
if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": |
|
|
|
model.has_moe_layers = True |
|
model.total_params = get_total_params(model.module) |
|
        print_rank_0(f" > total params: {model.total_params:,}")
|
|
|
if neox_args.is_pipe_parallel: |
|
model.set_has_attention_mask(True) |
|
if neox_args.curriculum_learning: |
|
curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning) |
|
if iteration is not None and iteration > 0: |
|
curr_scheduler.update_difficulty(iteration) |
|
else: |
|
curr_scheduler = None |
|
model.set_batch_fn( |
|
partial( |
|
get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler |
|
) |
|
) |
|
else: |
|
model.module.set_batch_fn( |
|
partial(get_batch_sequential, neox_args=neox_args) |
|
) |
|
|
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
|
|
if neox_args.load is not None: |
|
neox_args.iteration = load_checkpoint( |
|
neox_args=neox_args, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
iteration=iteration, |
|
) |
|
if needs_reference_model: |
|
_ = load_checkpoint( |
|
neox_args=neox_args, |
|
model=reference_model, |
|
optimizer=ref_optimizer, |
|
lr_scheduler=ref_lr_scheduler, |
|
iteration=iteration, |
|
) |
|
reference_model.eval() |
|
print_rank_0( |
|
f"Loading checkpoint and starting from iteration {neox_args.iteration}" |
|
) |
|
else: |
|
neox_args.iteration = 0 |
|
|
|
|
|
|
|
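    # deepspeed.initialize may wrap or replace the optimizer (fp16/ZeRO wrappers), so point the
    # LR scheduler at the engine's optimizer instance.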
if lr_scheduler is not None: |
|
lr_scheduler.optimizer = model.optimizer |
|
|
|
return model, optimizer, lr_scheduler, reference_model |
|
|
|
|
|
def backward_step(neox_args, timers, optimizer, model, loss): |
|
"""Backward step.""" |
|
|
|
|
|
timers("backward-backward").start() |
|
if neox_args.deepspeed: |
|
model.backward(loss) |
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
timers("backward-backward").stop() |
|
|
|
if neox_args.deepspeed: |
|
|
|
|
|
timers("backward-allreduce").reset() |
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
|
|
|
|
def train_step( |
|
neox_args, |
|
timers, |
|
data_iterator, |
|
model, |
|
optimizer, |
|
lr_scheduler, |
|
reference_model=None, |
|
): |
|
"""Single training step.""" |
|
|
|
|
|
if neox_args.is_pipe_parallel: |
|
reduced_loss = train_step_pipe( |
|
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator |
|
) |
|
reduce_metrics = reduced_loss |
|
if ( |
|
neox_args.memory_profiling |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
and torch.distributed.get_rank() == 0 |
|
): |
|
save_snapshot(neox_args) |
|
else: |
|
losses = [] |
|
metric_dicts = defaultdict(list) |
|
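        # DeepSpeed handles gradient accumulation internally: model.backward() accumulates across
        # micro-batches and model.step() only applies a real update at accumulation boundaries.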
for _ in range(neox_args.gradient_accumulation_steps): |
|
|
|
timers("forward").start() |
|
loss, metric_dict = forward_step( |
|
neox_args=neox_args, |
|
timers=timers, |
|
data_iterator=data_iterator, |
|
model=model, |
|
is_train=True, |
|
reference_model=reference_model, |
|
) |
|
timers("forward").stop() |
|
losses.append(loss) |
|
for key in metric_dict.keys(): |
|
metric_dicts[key].append(metric_dict[key]) |
|
|
|
if ( |
|
neox_args.profile |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
): |
|
torch.cuda.nvtx.range_push(f"Backward pass") |
|
timers("backward").start() |
|
backward_step( |
|
neox_args=neox_args, |
|
timers=timers, |
|
optimizer=optimizer, |
|
model=model, |
|
loss=loss, |
|
) |
|
timers("backward").stop() |
|
if ( |
|
neox_args.profile |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
): |
|
torch.cuda.nvtx.range_pop() |
|
|
|
if ( |
|
neox_args.profile |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
): |
|
torch.cuda.nvtx.range_push(f"Optimizer step") |
|
|
|
timers("optimizer").start() |
|
if neox_args.deepspeed: |
|
model.step() |
|
else: |
|
raise ValueError("Must be using deepspeed to run neox") |
|
timers("optimizer").stop() |
|
if ( |
|
neox_args.profile |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
): |
|
torch.cuda.nvtx.range_pop() |
|
if ( |
|
neox_args.profile |
|
and neox_args.iteration >= neox_args.profile_step_start |
|
and neox_args.iteration <= neox_args.profile_step_stop |
|
and torch.distributed.get_rank() == 0 |
|
): |
|
save_snapshot(neox_args) |
|
|
|
reduce_metrics = { |
|
key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys() |
|
} |
|
reduce_metrics["lm_loss"] = reduce_losses(losses).mean() |
|
|
|
if neox_args.precision == "fp16" and model.optimizer.overflow: |
|
skipped_iter = 1 |
|
else: |
|
skipped_iter = 0 |
|
|
|
collect_loss_for_unit_test(reduce_metrics["lm_loss"]) |
|
return reduce_metrics, skipped_iter |
|
|
|
|
|
def train_step_pipe(neox_args, timers, model, data_iterator): |
|
"""Single training step with DeepSpeed's pipeline parallel engine.""" |
|
|
|
assert neox_args.deepspeed |
|
loss = model.train_batch(data_iter=data_iterator) |
|
loss_dict = {"lm_loss": loss} |
|
|
|
for t in [ |
|
"forward", |
|
"backward", |
|
"allreduce", |
|
"optimizer", |
|
"batch generator", |
|
"data loader", |
|
]: |
|
timers(t).reset() |
|
return loss_dict |
|
|
|
|
|
def is_save_iter(neox_args, iteration): |
|
if neox_args.extra_save_iters and iteration in neox_args.extra_save_iters: |
|
return True |
|
|
|
if neox_args.checkpoint_factor: |
|
if neox_args.checkpoint_scale == "linear": |
|
assert float( |
|
neox_args.checkpoint_factor |
|
).is_integer(), "checkpoint_factor must be a whole number when using linear checkpoint_scale" |
|
return iteration % neox_args.checkpoint_factor == 0 |
|
elif neox_args.checkpoint_scale == "log": |
|
|
|
assert neox_args.checkpoint_factor > 1 |
|
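            # "log" scale: save whenever the iteration is an exact power of checkpoint_factor.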
power = 1 |
|
while power < iteration + 1: |
|
if int(power) == iteration: |
|
return True |
|
power *= neox_args.checkpoint_factor |
|
return False |
|
|
|
return False |
|
|
|
|
|
def train( |
|
neox_args, |
|
timers, |
|
model, |
|
reference_model, |
|
optimizer, |
|
lr_scheduler, |
|
train_data_iterator, |
|
valid_data_iterator, |
|
): |
|
"""Train the model function.""" |
|
|
|
|
|
model.train() |
|
|
|
|
|
total_loss_dict = {} |
|
|
|
|
|
iteration = neox_args.iteration |
|
|
|
timers("interval time").start() |
|
report_memory_flag = True |
|
|
|
|
|
noise_scale_logger = get_noise_scale_logger(neox_args) |
|
|
|
|
|
overflow_monitor = OverflowMonitor(optimizer) |
|
|
|
if neox_args.profile: |
|
schedule = torch.profiler.schedule( |
|
wait=neox_args.profile_step_start, |
|
warmup=1, |
|
active=neox_args.profile_step_stop - neox_args.profile_step_start, |
|
) |
|
prof = torch.profiler.profile( |
|
schedule=schedule, |
|
on_trace_ready=torch.profiler.tensorboard_trace_handler( |
|
neox_args.tensorboard_dir |
|
), |
|
record_shapes=True, |
|
profile_memory=True, |
|
with_flops=True, |
|
with_modules=True, |
|
with_stack=True, |
|
) |
|
prof.start() |
|
while iteration < neox_args.train_iters: |
|
if neox_args.profile: |
|
prof.step() |
|
if neox_args.profile and iteration == neox_args.profile_step_start: |
|
torch.cuda.cudart().cudaProfilerStart() |
|
loss_dict, skipped_iter = train_step( |
|
neox_args=neox_args, |
|
timers=timers, |
|
data_iterator=train_data_iterator, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
reference_model=reference_model, |
|
) |
|
if neox_args.profile and iteration == neox_args.profile_step_stop: |
|
torch.cuda.cudart().cudaProfilerStop() |
|
prof.stop() |
|
iteration += 1 |
|
neox_args.iteration = iteration |
|
if neox_args.precision == "fp16": |
|
overflow_monitor.check(skipped_iter) |
|
if neox_args.log_gradient_noise_scale: |
|
noise_scale_logger.update() |
|
|
|
|
|
|
|
if optimizer.param_groups: |
|
lr = optimizer.param_groups[0].get("lr", 0) |
|
else: |
|
lr = 0 |
|
|
|
|
|
report_memory_flag = training_log( |
|
neox_args=neox_args, |
|
timers=timers, |
|
loss_dict=loss_dict, |
|
total_loss_dict=total_loss_dict, |
|
learning_rate=lr, |
|
iteration=iteration, |
|
loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None, |
|
report_memory_flag=report_memory_flag, |
|
skipped_iter=skipped_iter, |
|
model=model, |
|
optimizer=optimizer, |
|
noise_scale_logger=noise_scale_logger, |
|
) |
|
|
|
|
|
if neox_args.save and is_save_iter(neox_args, iteration): |
|
save_checkpoint( |
|
neox_args=neox_args, |
|
iteration=iteration, |
|
model=model, |
|
optimizer=optimizer, |
|
lr_scheduler=lr_scheduler, |
|
) |
|
|
|
if ( |
|
neox_args.eval_interval |
|
and iteration % neox_args.eval_interval == 0 |
|
and neox_args.do_valid |
|
): |
|
prefix = "iteration {}".format(iteration) |
|
evaluate_and_print_results( |
|
neox_args=neox_args, |
|
prefix=prefix, |
|
forward_step_func=forward_step, |
|
data_iterator=valid_data_iterator, |
|
model=model, |
|
iteration=iteration, |
|
verbose=False, |
|
timers=timers, |
|
reference_model=reference_model, |
|
) |
|
|
|
if neox_args.exit_interval and iteration % neox_args.exit_interval == 0: |
|
torch.distributed.barrier() |
|
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
rank = torch.distributed.get_rank() |
|
print_rank_0( |
|
"rank: {} | time: {} | exiting the program at iteration {}".format( |
|
rank, time_str, iteration |
|
) |
|
) |
|
sys.exit() |
|
|
|
return iteration |
|
|
|
|
|
def evaluate( |
|
neox_args, |
|
forward_step_fn, |
|
data_iterator, |
|
model, |
|
verbose=False, |
|
timers=None, |
|
reference_model=None, |
|
):
    """Evaluation.

    neox_args: NeoX Arguments
    forward_step_fn: function with args `neox_args, timers, data_iterator, model`
        that will run a forward pass on the model
    data_iterator: Iterator that iterates over batches of data. Should return data in the form:
        {'text': np.array([tokens], dtype=np.int64)}
        where the size of the array is the model's context size + 1
        (`get_batch` transforms it into inputs / labels)
    """
|
|
|
model.eval() |
|
losses = [] |
|
metric_dicts = defaultdict(list) |
|
if neox_args.char_level_ppl: |
|
data_iterator = CharCounter(data_iterator, neox_args.tokenizer) |
|
|
|
with torch.no_grad(): |
|
iteration = 0 |
|
while iteration < neox_args.eval_iters: |
|
iteration += 1 |
|
if verbose and iteration % neox_args.log_interval == 0: |
|
print_rank_0( |
|
"Evaluating iter {}/{}".format(iteration, neox_args.eval_iters) |
|
) |
|
|
|
|
|
|
|
|
|
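            # Non-pipeline models see one micro-batch per forward_step call, so loop over
            # gradient_accumulation_steps here; pipeline engines consume a full batch inside eval_batch.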
for _ in range( |
|
1 |
|
if neox_args.is_pipe_parallel |
|
else neox_args.gradient_accumulation_steps |
|
): |
|
|
|
loss, metric_dict = forward_step_fn( |
|
model=model, |
|
data_iterator=data_iterator, |
|
neox_args=neox_args, |
|
timers=timers, |
|
reference_model=reference_model, |
|
) |
|
losses.append(loss) |
|
for key in metric_dict.keys(): |
|
metric_dicts[key].append(metric_dict[key]) |
|
|
|
|
|
|
|
|
|
if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing: |
|
deepspeed.checkpointing.reset() |
|
|
|
|
|
eval_results = {"lm_loss": reduce_losses(losses).mean().item()} |
|
for key in metric_dicts.keys(): |
|
eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() |
|
eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) |
|
|
|
if neox_args.char_level_ppl: |
|
|
|
|
|
|
|
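        # Convert token-level loss to character-level perplexity via the measured tokens-per-char
        # ratio, then unwrap the CharCounter to restore the original iterator.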
tokens_per_char = data_iterator.tokens_per_char() |
|
print_rank_0(f"Counting chars took {data_iterator.total_time} seconds") |
|
|
|
data_iterator = data_iterator.data_iterator |
|
eval_results["lm_loss_char_lvl_ppl"] = math.exp( |
|
eval_results["lm_loss"] * tokens_per_char |
|
) |
|
|
|
if neox_args.eval_tasks: |
|
from eval_tasks import run_eval_harness |
|
|
|
eval_results.update( |
|
run_eval_harness( |
|
model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks |
|
).get("results") |
|
) |
|
|
|
model.train() |
|
return eval_results |
|
|
|
|
|
def collect_loss_for_unit_test(lm_loss):
    """Hook that lets unit tests capture the reduced LM loss; intentionally a no-op during training."""
    pass
|
|
|
|
|
def evaluate_and_print_results( |
|
neox_args, |
|
prefix, |
|
forward_step_func, |
|
data_iterator, |
|
model, |
|
iteration, |
|
verbose=False, |
|
timers=None, |
|
chart_name="validation", |
|
reference_model=None, |
|
): |
|
"""Helper function to evaluate and dump results on screen.""" |
|
total_loss_dict = evaluate( |
|
neox_args=neox_args, |
|
forward_step_fn=forward_step_func, |
|
data_iterator=data_iterator, |
|
model=model, |
|
verbose=verbose, |
|
timers=timers, |
|
reference_model=reference_model, |
|
) |
|
string = f" {chart_name} results at {prefix} | " |
|
for k, v in total_loss_dict.items(): |
|
if isinstance(v, dict): |
|
if neox_args.eval_tasks and "results" in v: |
|
v = v["results"] |
|
print(v) |
|
for k2, v2 in v.items(): |
|
k3 = "_".join([k, k2]) |
|
string += f"{k3} value: {v2:.6E} | " |
|
tb_wandb_log( |
|
f"{chart_name}/{k3}", |
|
v2, |
|
iteration, |
|
use_wandb=neox_args.use_wandb, |
|
tensorboard_writer=neox_args.tensorboard_writer, |
|
comet_experiment=neox_args.comet_experiment, |
|
) |
|
else: |
|
string += f"{k} value: {v:.6E} | " |
|
tb_wandb_log( |
|
f"{chart_name}/{k}", |
|
v, |
|
iteration, |
|
use_wandb=neox_args.use_wandb, |
|
tensorboard_writer=neox_args.tensorboard_writer, |
|
comet_experiment=neox_args.comet_experiment, |
|
) |
|
|
|
length = len(string) + 1 |
|
print_rank_0("-" * length) |
|
print_rank_0(string) |
|
print_rank_0("-" * length) |
|
|
|
|
|
def save_snapshot(neox_args): |
|
assert ( |
|
neox_args.memory_profiling_path is not None |
|
), "Must pass memory_profiling_path config arg to use profiling" |
|
snapshot = torch.cuda.memory._snapshot() |
|
    snapshot_path = neox_args.memory_profiling_path
|
if not os.path.exists(snapshot_path): |
|
os.makedirs(snapshot_path) |
|
with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f: |
|
dump(snapshot, f) |
|
|