# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version
#
"""Pretrain utilities."""
from datetime import datetime
from functools import partial
from collections import defaultdict
import math
import sys
from contextlib import nullcontext
import torch
import torch.nn.functional as F
import deepspeed
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
import numpy as np
from megatron.utils import (
Timers,
init_wandb,
get_ltor_masks_and_position_ids,
reduce_losses,
)
from megatron import print_rank_0, mpu
from megatron.model import (
GPT2ModelPipe,
SoftEmbedding,
get_params_for_weight_decay_optimization,
mark_norms_for_sequence_parallel_grad_sync,
)
from megatron.mpu.mappings import gather_from_model_parallel_region
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import (
build_train_valid_test_data_loaders,
shift_and_wrap_data_loaders,
)
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.logging import tb_wandb_log, training_log
from megatron.utils import (
OverflowMonitor,
get_noise_scale_logger,
get_total_params,
CharCounter,
)
from megatron.model.gpt2_model import cross_entropy
from megatron.mpu import vocab_parallel_cross_entropy
from pickle import dump
import os
def mup_weights_reinit(neox_args, model):
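    """Re-run muP's replacement weight init on every module that supports it; assumes
    set_base_shapes has already given each weight an .infshape attribute."""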
def has_method(o, name):
return callable(getattr(o, name, None))
for layer in model.modules():
# This normally would happen in set_base_shapes if we actually were able to use the MuReadout class
if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters:
layer._rescale_parameters()
if has_method(layer, "mup_reinitialize_weights"):
layer.mup_reinitialize_weights(neox_args)
def save_base_shapes(neox_args, base_shapes, use_cache):
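    """Compute muP base shapes from a base model and a width-scaled delta model, save them to disk, and exit."""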
# Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here
neox_args.use_mup = False
base_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
        parallel_output=(neox_args.train_impl != "rm"),
topology=mpu.get_topology(),
use_cache=use_cache,
)
if not neox_args.is_pipe_parallel:
base_model = base_model.to_sequential()
try:
import mup
    except ModuleNotFoundError as e:
        raise ImportError(
            "Please install mup: https://github.com/microsoft/mup"
        ) from e
base_shapes = mup.get_shapes(base_model)
del base_model
old_hidden_size = neox_args.hidden_size
neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale
delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
        parallel_output=(neox_args.train_impl != "rm"),
topology=mpu.get_topology(),
use_cache=use_cache,
)
if not neox_args.is_pipe_parallel:
delta_model = delta_model.to_sequential()
delta_shapes = mup.get_shapes(delta_model)
# change back
neox_args.use_mup = True
neox_args.hidden_size = old_hidden_size
save_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
print(f"saving base shapes at {save_shapes}")
mup.make_base_shapes(base_shapes, delta_shapes, savefile=save_shapes)
print(f"base shapes saved...exiting")
sys.exit(1)
def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator):
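    """Run muP coordinate checks across a range of widths, for both muP and standard
    parametrization, and save the plots before exiting."""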
from megatron.mup_substitute import get_coord_data
from mup.coord_check import plot_coord_data
def lazy_model(hidden_size):
def gen():
old_hidden_size = neox_args.hidden_size
neox_args.hidden_size = hidden_size
model, optimizer, _, _ = setup_model_and_optimizer(
neox_args=neox_args, use_cache=False
)
neox_args.hidden_size = old_hidden_size
return model
return gen
models = {}
# Hidden size needs to be divisible by num attention heads
for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)):
models[hidden_size] = lazy_model(hidden_size)
neox_args.use_mup = True
df_up = get_coord_data(
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True
)
neox_args.use_mup = False
df_sp = get_coord_data(
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False
)
plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg")
plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg")
print_rank_0("Saved coord check plots... exiting")
sys.exit(1)
def update_iterations(neox_args, data_loaders):
"""
    Compute the number of train iterations from train_epochs when train_iters is not specified, and update the neox_args object.
Note that if len(train_dataloader) % gradient_accumulation_steps != 0, this will configure neox
to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs
times.
"""
if (not neox_args.do_train) or (neox_args.train_iters is not None):
pass
elif neox_args.train_iters is None and neox_args.train_epochs is None:
print_rank_0(
"ERROR:Failed to specify either train_epochs or train_iters in config file"
)
else:
global_rank = torch.distributed.get_rank()
if global_rank == 0:
train_dataloader = data_loaders["train"]
train_epochs = neox_args.train_epochs
gradient_accumulation_steps = neox_args.gradient_accumulation_steps
train_dataloader_len = len(train_dataloader)
train_iterations = (
train_dataloader_len * train_epochs
) // gradient_accumulation_steps
train_iters_tensor = torch.cuda.LongTensor([train_iterations])
else:
train_iters_tensor = torch.cuda.LongTensor([0])
torch.distributed.broadcast(train_iters_tensor, src=0)
neox_args.train_iters = train_iters_tensor[0].item()
print_rank_0(
f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
)
def pretrain(neox_args):
"""Main training program.
This function will run the following in the order provided:
1) initialize Megatron.
2) get train/val/test datasets.
3) setup model, optimizer and lr schedule.
4) configure data loading
5) train the model.
Arguments:
neox_args: an instance of NeoXArgs containing the configuration for pretrain
"""
# setup logging and timers
init_wandb(neox_args=neox_args)
timers = Timers(
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(neox_args=neox_args)
# Create data loaders
timers("train/valid/test data loaders").start()
data_loaders = build_train_valid_test_data_loaders(neox_args=neox_args)
update_iterations(neox_args=neox_args, data_loaders=data_loaders)
timers("train/valid/test data loaders").stop()
# Model, optimizer, and learning rate.
timers("model and optimizer").start()
model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
neox_args=neox_args, use_cache=False, iteration=neox_args.iteration
)
timers("model and optimizer").stop()
# Make and configure iterators
timers("train/valid/test data iterators").start()
(
train_data_iterator,
valid_data_iterator,
test_data_iterator,
) = shift_and_wrap_data_loaders(neox_args=neox_args, data_loaders=data_loaders)
timers("train/valid/test data iterators").stop()
if neox_args.use_mup and neox_args.coord_check:
mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)
# Print setup timing.
print_rank_0("done with setups ...")
timers.log(
[
"train/valid/test data loaders",
"model and optimizer",
"train/valid/test data iterators",
]
)
print_rank_0("training ...")
iteration = neox_args.iteration
# edge case: save step 0 checkpoint if requested and we're starting from step 0
if (
neox_args.save
and neox_args.extra_save_iters
and 0 in neox_args.extra_save_iters
and iteration == 0
):
save_checkpoint(
neox_args=neox_args,
iteration=iteration,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
if neox_args.do_train and neox_args.train_iters > 0:
iteration = train(
neox_args=neox_args,
timers=timers,
model=model,
reference_model=reference_model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_data_iterator=train_data_iterator,
valid_data_iterator=valid_data_iterator,
)
if neox_args.do_valid:
prefix = "the end of training for val data"
evaluate_and_print_results(
neox_args=neox_args,
prefix=prefix,
forward_step_func=forward_step,
data_iterator=valid_data_iterator,
model=model,
iteration=iteration,
verbose=False,
timers=timers,
reference_model=reference_model,
)
if neox_args.save and iteration != 0:
save_checkpoint(
neox_args=neox_args,
iteration=iteration,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
if neox_args.do_test:
# Run on test data.
prefix = "the end of training for test data"
evaluate_and_print_results(
neox_args=neox_args,
prefix=prefix,
forward_step_func=forward_step,
data_iterator=test_data_iterator,
model=model,
iteration=iteration,
verbose=True,
timers=timers,
chart_name="test",
reference_model=reference_model,
)
def _get_batch(neox_args, tokenizer, keys, data, datatype, label_mask_zero=False):
"""Support function for get_batch / get_batch pipe (to avoid code repetition)"""
data_b = mpu.broadcast_data(keys, data, datatype)
token_key = keys[0]
label_key = keys[1] if len(keys) > 1 else None
# Unpack.
tokens_ = data_b[token_key].long()
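    # When a label stream is present, negative label values mark positions that
    # should be excluded from the loss; build the mask first, then clamp them to 0.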
if label_key in data_b:
label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous()
labels = torch.where(
data_b[label_key].long() >= 0,
data_b[label_key].long(),
torch.zeros_like(data_b[label_key].long()),
)[:, 1:].contiguous()
else:
label_mask = (tokens_.long() >= 0)[:, 1:].contiguous()
labels = tokens_[:, 1:].contiguous()
if label_mask_zero:
labels = labels * label_mask
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
sliding_window_width=neox_args.sliding_window_width,
)
# combine loss masks from get_ltor_masks_and_position_ids with loss masks from data
loss_mask = label_mask.to(loss_mask.dtype) * loss_mask
return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch(neox_args, data_iterator):
"""Generate a batch"""
# Items and their type.
if neox_args.train_impl in ["normal", "kto"]:
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
elif neox_args.train_impl in ["dpo", "rm"]:
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
if neox_args.pos_train_label_data_paths
else [["pos"], ["neg"]]
)
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
if neox_args.train_impl == "normal":
return _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
keys=keys,
data=data,
datatype=datatype,
)
elif neox_args.train_impl == "kto":
assert (
neox_args.train_micro_batch_size_per_gpu > 1
), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
keys=keys,
data=data,
datatype=datatype,
)
        # Drop the last token from the reward: since the model predicts the next
        # token, the reward for the current prediction aligns with the next token's label.
rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][
:, :-1
].contiguous()
ref_data = (
mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous()
if neox_args.precompute_model_name
else None
)
return tup + (rw_data, ref_data)
elif neox_args.train_impl in ["dpo", "rm"]:
pos_tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
keys=keys[0],
data=data,
datatype=datatype,
label_mask_zero=True,
)
neg_tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
keys=keys[1],
data=data,
datatype=datatype,
label_mask_zero=True,
)
if neox_args.precompute_model_name:
ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float)
else:
ref_data = {"pos_ref": None}
return [
torch.cat((pos_item, neg_item), dim=0)
for pos_item, neg_item in zip(pos_tup, neg_tup)
] + [
torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[
:, :-1
].contiguous()
if ref_data["pos_ref"] is not None
else None
]
def get_batch_pipe(data, neox_args, curr_scheduler=None):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
assert neox_args.train_impl not in [
"kto",
"dpo",
"rm",
], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"
# Items and their type.
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
datatype = torch.int64
tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
neox_args, neox_args.tokenizer, keys, data, datatype
)
if curr_scheduler is not None:
# iteration + 1 to align with how/when DeepSpeed updates the buffers
curriculum_seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
if curriculum_seqlen < tokens.size()[1]:
# seqlen-based curriculum learning
# input_ids, position_ids, labels have size [batch size, seqlen]
# input_ids = input_ids[:, :curriculum_seqlen].contiguous()
tokens = tokens[:, :curriculum_seqlen].contiguous()
position_ids = position_ids[:, :curriculum_seqlen].contiguous()
if labels is not None:
labels = labels[:, :curriculum_seqlen].contiguous()
if loss_mask is not None:
loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
# attention_mask has size [1, 1, seqlen, seqlen]
attention_mask = attention_mask[
:, :, :curriculum_seqlen, :curriculum_seqlen
].contiguous()
# unpack data
return (tokens, position_ids, attention_mask), (labels, loss_mask)
def get_batch_sequential(forward_input, neox_args):
"""A modification of get_batch() to work with the latest batch instead of an iterator."""
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=forward_input[0],
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)
return (forward_input[0], forward_input[1], attention_mask)
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / torch.distributed.get_world_size(
group=mpu.get_data_parallel_group()
)
return averaged_losses
def mb_moe_loss_func(args, loss_mask, output_tensor=None):
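    """Compute the MegaBlocks MoE load-balancing loss, averaged across data parallel replicas.

    The main LM loss is computed elsewhere; this only returns the auxiliary term.
    """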
from megatron.model import megablocks_utils
from megatron.model.megablocks_utils import moe
# NOTE: For pipeline parallelism this function will be run on the
# non-final stages to calculate load balancing loss contribution
# for the MoE layers within the stage. For these cases, output_tensor
# will be None.
loss, loss_dict = (None, {})
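    # NOTE: the `if False` branch below is intentionally disabled (kept from upstream);
    # the LM loss is computed by the caller, so only the load-balancing term is returned.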
if False:
assert output_tensor is not None
loss, loss_dict = loss_func(loss_mask, output_tensor)
assert loss.numel() == 1
# NOTE: If recompute is enabled we will collect duplicate load
# balancing loss contributions. Prune these before calculating
# the load balancing loss.
if args.checkpoint_activations:
        # Ignore load balancing loss contributions computed during
        # the forward pass if recompute is turned on.
load_balancing_loss_data = moe.get_load_balancing_loss()
if args.num_layers * 2 == len(load_balancing_loss_data):
load_balancing_loss_data = load_balancing_loss_data[args.num_layers :]
moe.clear_load_balancing_loss()
for x in load_balancing_loss_data:
moe.save_load_balancing_loss(x)
# Compute the load balancing loss for all MoE layers.
    megablocks_args = megablocks_utils.as_megablocks_args(args)
lbl = moe.batched_load_balancing_loss(megablocks_args)
moe.clear_load_balancing_loss()
# Average the load balancing loss across data parallel
# replicas and save for logging.
averaged_lbl = average_losses_across_data_parallel_group([lbl])
loss_dict["load balancing loss"] = averaged_lbl[0]
return averaged_lbl, loss_dict
def get_logp(logits, labels, force_fp32=False):
    # Cross-entropy is the negative log-probability, so negate it rather than reimplementing logp.
if force_fp32:
logits = logits.float()
return -vocab_parallel_cross_entropy(logits, labels)
def get_pos_neg_logp(logits, labels, force_fp32=False):
    # Cross-entropy is the negative log-probability, so negate it rather than reimplementing logp.
if force_fp32:
logits = logits.float()
return torch.chunk(-vocab_parallel_cross_entropy(logits, labels), 2, 0)
def forward_step(
data_iterator,
model,
neox_args,
timers,
return_logits=False,
is_train=False,
reference_model=None,
):
"""Forward step."""
if neox_args.is_pipe_parallel:
return model.eval_batch(data_iterator, return_logits=return_logits)
# Get the batch.
if neox_args.memory_profiling and neox_args.iteration:
torch.cuda.nvtx.range_push(f"Get batch")
if timers is not None:
timers("batch generator").start()
if neox_args.train_impl == "normal":
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
elif neox_args.train_impl == "kto":
(
tokens,
labels,
loss_mask,
attention_mask,
position_ids,
rewards,
ref_logp,
) = get_batch(neox_args=neox_args, data_iterator=data_iterator)
if neox_args.train_impl in ["dpo", "rm"]:
tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
if timers is not None:
timers("batch generator").stop()
    # match the range_push condition above so push/pop stay balanced
    if neox_args.memory_profiling and neox_args.iteration:
torch.cuda.nvtx.range_pop()
if neox_args.memory_profiling:
torch.cuda.nvtx.range_push(f"Forward pass")
metrics = {}
if neox_args.train_impl == "normal":
# Sequential returns moe_losses, but this is not yet supported by pipe parallel
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(maybe_tuple) is tuple:
outputs, moe_losses = maybe_tuple
else:
outputs = maybe_tuple
moe_losses = []
if (
is_train
and neox_args.curriculum_learning
and neox_args.curriculum_seqlen < neox_args.seq_length
):
loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
main_loss = cross_entropy(
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
)
if neox_args.moe_num_experts > 1:
if neox_args.moe_type == "deepspeed":
moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses)
elif neox_args.moe_type == "megablocks":
moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0]
else:
raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}")
else:
moe_loss = 0.0
loss = main_loss + moe_loss
elif neox_args.train_impl == "rm":
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(maybe_tuple) is tuple:
outputs, _ = maybe_tuple
else:
outputs = maybe_tuple
pos, neg = torch.chunk(outputs, 2, 0)
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
        # We assume that each pos/neg pair occurs in the same order,
        # e.g. the second nonzero pos is the corresponding second nonzero neg,
        # and that each sequence contains an equal number of pos and neg labels.
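        # The loss below is a pairwise Bradley-Terry-style term, -log sigmoid(pos - neg),
        # plus a z-loss penalty on the squared reward magnitudes.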
pos_indx = pos_loss_mask.nonzero()
neg_indx = neg_loss_mask.nonzero()
# indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
pos_indx = pos_indx[:, 1].unsqueeze(1)
neg_indx = neg_indx[:, 1].unsqueeze(1)
pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
with torch.no_grad():
metrics["pos_values"] = pos.clone().detach().mean()
metrics["neg_values"] = neg.clone().detach().mean()
metrics["margin"] = (pos - neg).clone().detach().mean()
metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
loss = (-F.logsigmoid(pos - neg).mean()) + (
(neox_args.z_loss * (pos**2 + neg**2)).mean()
)
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
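        # DPO loss: -log sigmoid(beta * ((logp_pi(pos) - logp_pi(neg)) - (logp_ref(pos) - logp_ref(neg)))),
        # where the logps are summed over each sequence's masked tokens.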
with torch.inference_mode():
# So we can gather token logps...
token_logp_labels = labels.clone()
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
if neox_args.dpo_reference_free:
ref_pos = 0
ref_neg = 0
elif ref_logp is None:
ref_maybe_tuple = reference_model(
(tokens, position_ids, attention_mask), neox_args=neox_args
)
if type(ref_maybe_tuple) is tuple:
                    # Ignore MoE losses for the reference model.
ref_outputs, _ = ref_maybe_tuple
else:
ref_outputs = ref_maybe_tuple
ref_pos, ref_neg = get_pos_neg_logp(
ref_outputs, token_logp_labels, neox_args.dpo_fp32
)
else:
ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0)
ref_pos = (ref_pos * pos_loss_mask).sum(-1)
ref_neg = (ref_neg * neg_loss_mask).sum(-1)
chosen_maybe_tuple = model(
(tokens, position_ids, attention_mask), neox_args=neox_args
)
if type(chosen_maybe_tuple) is tuple:
            # Ignore MoE losses here.
chosen_outputs, _ = chosen_maybe_tuple
else:
chosen_outputs = chosen_maybe_tuple
chosen_pos, chosen_neg = get_pos_neg_logp(
chosen_outputs, token_logp_labels, neox_args.dpo_fp32
)
chosen_pos = (chosen_pos * pos_loss_mask).sum(-1)
chosen_neg = (chosen_neg * neg_loss_mask).sum(-1)
with torch.no_grad():
# Collect metrics...
if not neox_args.dpo_reference_free:
metrics["ref_neg"] = ref_neg.clone().detach().mean()
metrics["ref_pos"] = ref_pos.clone().detach().mean()
metrics["chosen_neg"] = chosen_neg.clone().detach().mean()
metrics["chosen_pos"] = chosen_pos.clone().detach().mean()
if not neox_args.dpo_reference_free:
chosen_rewards = neox_args.dpo_beta * (
chosen_pos.clone().detach() - ref_pos.clone().detach()
)
rejected_rewards = neox_args.dpo_beta * (
chosen_neg.clone().detach() - ref_neg.clone().detach()
)
metrics["chosen_rewards"] = chosen_rewards.mean()
metrics["rejected_rewards"] = rejected_rewards.mean()
reward_acc = (chosen_rewards > rejected_rewards).float()
metrics["reward_acc"] = reward_acc.mean()
metrics["margins"] = (chosen_rewards - rejected_rewards).mean()
        pi_logratios = chosen_pos - chosen_neg
        ref_logratios = ref_pos - ref_neg
        logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean()
elif neox_args.train_impl == "kto":
# Based on https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py
# Except we don't have an extra input for KL logp, we just split the batch in half
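        # Each half of the microbatch serves as the (detached) KL estimate for the other
        # half, which is why KTO requires train_micro_batch_size_per_gpu > 1.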
with torch.no_grad():
# So we can gather token logps...
token_logp_labels = labels.clone()
token_logp_labels[token_logp_labels == -100] = 0
if ref_logp is None:
# Did not precompute logits....
ref_maybe_tuple = reference_model(
(tokens, position_ids, attention_mask), neox_args=neox_args
)
if type(ref_maybe_tuple) is tuple:
                    # Ignore MoE losses for the reference model.
ref_outputs, _ = ref_maybe_tuple
else:
ref_outputs = ref_maybe_tuple
# gather across tensor parallel group
ref_outputs = gather_from_model_parallel_region(ref_outputs)
ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32)
else:
print(f"REF LOGP: {ref_logp.clone().detach().mean()}")
ref_logp = ref_logp * loss_mask
scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight
scaling += (
rewards.sum(-1) < -0.001
).float() * neox_args.kto_undesirable_weight
pos_mask = (rewards > 0.001).float()
neg_mask = (rewards < -0.001).float()
chosen_maybe_tuple = model(
(tokens, position_ids, attention_mask), neox_args=neox_args
)
if type(chosen_maybe_tuple) is tuple:
            # Ignore MoE losses here.
chosen_outputs, _ = chosen_maybe_tuple
else:
chosen_outputs = chosen_maybe_tuple
chosen_outputs = gather_from_model_parallel_region(chosen_outputs)
chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32)
chosen_logp = chosen_logp * loss_mask
with torch.no_grad():
# Collect metrics...
metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean()
metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean()
metrics["pos_ref_logp"] = (
(ref_logp * pos_mask).clone().detach().sum(-1).mean()
)
metrics["neg_ref_logp"] = (
(ref_logp * neg_mask).clone().detach().sum(-1).mean()
)
metrics["pos_policy_logp"] = (
(chosen_logp * pos_mask).clone().detach().sum(-1).mean()
)
metrics["neg_policy_logp"] = (
(chosen_logp * neg_mask).clone().detach().sum(-1).mean()
)
metrics["kl"] = (
chosen_logp.clone().detach() - ref_logp.clone().detach()
).sum() / loss_mask.sum()
policy_rewards = (
neox_args.kto_beta
* rewards
* (chosen_logp.clone().detach() - ref_logp.clone().detach())
)
reward_acc = (policy_rewards.sum(-1) > 0.0).float()
metrics["reward_acc"] = reward_acc.mean()
metrics["policy_rewards"] = policy_rewards.sum()
print(metrics)
pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0)
ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0)
reward1, reward2 = torch.chunk(rewards, 2, 0)
scaling1, scaling2 = torch.chunk(scaling, 2, 0)
kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean()
kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean()
log_ratio1 = pol_logp1 - ref_logp1
log_ratio2 = pol_logp2 - ref_logp2
# TODO: Add pack_until_overflow sequence support
loss = (
0.5
* scaling1.mean(-1)
* (
1
                - torch.sigmoid(
(
neox_args.kto_beta
* reward1.mean(-1)
* (log_ratio1.sum(-1) - kl2.clone().detach())
)
)
)
) + (
0.5
* scaling2.mean(-1)
* (
1
                - torch.sigmoid(
(
neox_args.kto_beta
* reward2.mean(-1)
* (log_ratio2.sum(-1) - kl1.clone().detach())
)
)
)
)
        loss = loss.mean()
if neox_args.memory_profiling:
torch.cuda.nvtx.range_pop()
if return_logits:
return loss, outputs, metrics
return loss, metrics
def get_model(neox_args, use_cache=False):
"""Build the model."""
# Build model on cpu.
print_rank_0("building GPT2 model ...")
# Temporarily disable mup so that the base model does not use the mup init functions before set_base_shapes is called below.
# If mup isn't being used anyways, this has no effect.
old_use_mup = neox_args.use_mup
neox_args.use_mup = False
if neox_args.zero_stage in [2, 3]:
if neox_args.pipe_parallel_size == 1:
print_rank_0(
"ZeRO stage 2/3 and the PipelineModule are incompatible, please set 'pipe_parallel_size' to 0 instead"
)
exit()
        if neox_args.pipe_parallel_size > 1:
            print_rank_0(
                "ZeRO stage 2/3 and pipeline parallelism are not supported simultaneously"
            )
            exit()
        if neox_args.model_parallel_size > 1:
            print_rank_0(
                "ZeRO stage 2/3 and model parallelism are not currently supported simultaneously"
            )
            exit()
with deepspeed.zero.Init(
config_dict_or_path=neox_args.deepspeed_config
) if neox_args.zero_stage == 3 else nullcontext() as gs:
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
            parallel_output=(neox_args.train_impl != "rm"),
topology=mpu.get_topology(),
use_cache=use_cache,
)
### soft prompt tuning stuff ###
if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(
"enabled", False
):
soft_prompt = SoftEmbedding(
neox_args,
wte=getattr(model, "0").word_embeddings,
n_tokens=neox_args.soft_prompt_tuning.get("n_tokens", 10),
init_string=neox_args.soft_prompt_tuning.get("init_string", ""),
init_range=neox_args.soft_prompt_tuning.get("init_range", 0.5),
)
model.insert_layers(
layers=soft_prompt, idx=1
) # insert the soft prompt layer directly after the word embeddings
# freeze everything but the soft prompt
for name, param in model.named_parameters():
if not "soft_embedding" in name:
param.requires_grad = False
if not neox_args.is_pipe_parallel:
# Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training
model = model.to_sequential()
neox_args.use_mup = old_use_mup
if neox_args.use_mup:
try:
import mup
        except ModuleNotFoundError as e:
            raise ImportError(
                "Please install mup: https://github.com/microsoft/mup"
            ) from e
base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
if neox_args.save_base_shapes:
save_base_shapes(neox_args, base_shapes, use_cache)
mup.set_base_shapes(model, base_shapes)
# Call the mup replacement init functions on the model now that set_base_shapes has given each weight a .infshape attribute
mup_weights_reinit(neox_args, model)
if neox_args.deepspeed:
# DeepSpeed handles CUDA, FP16, and DDP components.
return model
else:
raise ValueError("Must be using deepspeed to run neox")
def get_optimizer(model, neox_args, dummy=False):
"""Set up the optimizer."""
if neox_args.no_load_optim and neox_args.deepspeed:
        # DeepSpeed requires an optimizer object, so construct a minimal dummy one.
dummy = True
neox_args.optimizer = {"params": {"lr": 0.0}}
neox_args.optimizer_type = "adam"
elif neox_args.no_load_optim:
return None, None
if neox_args.optimizer is None:
print_rank_0(
f"ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)"
)
exit()
# Build parameter groups (weight decay and non-decay).
param_groups = get_params_for_weight_decay_optimization(model, neox_args)
print_rank_0(
f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
)
if neox_args.create_moe_param_group:
from deepspeed.moe.utils import (
is_moe_param,
split_params_into_different_moe_groups_for_optimizer,
)
param_groups = split_params_into_different_moe_groups_for_optimizer(
param_groups
)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group["params"]:
if not hasattr(param, "model_parallel"):
param.model_parallel = False
# Filter out params that don't require a grad (for soft prompt tuning, etc.)
_param_groups = []
for param_group in param_groups:
trainable_params = [p for p in param_group["params"] if p.requires_grad]
if dummy:
trainable_params = [trainable_params[0]] # just take the first one
param_group["params"] = trainable_params
_param_groups.append(param_group)
if dummy:
# Only need one.
break
param_groups = _param_groups
# If we're using mup, then the optimizer must be adam or sgd
assert not neox_args.use_mup or (
neox_args.optimizer_type.lower() == "adam"
or neox_args.optimizer_type.lower() == "sgd"
), f"If use_mup == True, you must specify either the adam or sgd optimizers. You passed: {neox_args.optimizer_type.lower()}"
if neox_args.optimizer_type.lower() in ["cpu_adam", "cpu_torch_adam"]:
        if neox_args.optimizer_type.lower() == "cpu_torch_adam":
cpu_adam_optimizer = torch.optim.Adam
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam
cpu_adam_optimizer = DeepSpeedCPUAdam
optimizer = cpu_adam_optimizer(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "onebitadam":
assert neox_args.deepspeed
optimizer = None
# onebitadam needs to be instantiated within the deepspeed engine to work :|
elif neox_args.optimizer_type.lower() == "sm3":
from .optimizers import SM3
optimizer = SM3(param_groups, **neox_args.optimizer["params"])
elif neox_args.optimizer_type.lower() == "madgrad_wd":
from .optimizers import madgrad_wd
optimizer = madgrad_wd(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "lion":
        # Use DeepSpeed's FusedLion under ZeRO; the Megatron Lion implementation raises a DeepSpeed error there.
if neox_args.zero_optimization["stage"] != 0:
from deepspeed.ops.lion import FusedLion
lion_optimizer = FusedLion
# if not zero
else:
from .optimizers import Lion
lion_optimizer = Lion
optimizer = lion_optimizer(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "adam":
# Use Adam
if neox_args.use_mup:
            try:
                from mup import MuAdam

                adam_optimizer = MuAdam
            except ModuleNotFoundError as e:
                raise ImportError(
                    "Please install mup: https://github.com/microsoft/mup"
                ) from e
else:
if neox_args.use_bnb_optimizer:
                try:
                    import bitsandbytes as bnb

                    adam_optimizer = bnb.optim.Adam8bit
                except ModuleNotFoundError as e:
                    raise ImportError(
                        "Please install bitsandbytes following "
                        "https://github.com/facebookresearch/bitsandbytes."
                    ) from e
else:
try:
# default to apex as it's slightly faster
from apex.optimizers import FusedAdam as Adam
except ImportError:
# if apex isn't installed, use deepspeed's FusedAdam
print(
"WARNING: APEX not installed - defaulting to deepspeed's fused adam"
)
from deepspeed.ops.adam import FusedAdam as Adam
adam_optimizer = Adam
optimizer = adam_optimizer(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "sgd":
try:
from mup import MuSGD
        except ModuleNotFoundError as e:
            raise ImportError(
                "Please install mup: https://github.com/microsoft/mup"
            ) from e
optimizer = MuSGD(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
else:
raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")
if neox_args.deepspeed:
# fp16 wrapper is not required for DeepSpeed.
return optimizer, param_groups
else:
raise ValueError("Must be using deepspeed to run neox")
def get_learning_rate_scheduler(optimizer, neox_args):
"""Build the learning rate scheduler."""
if (neox_args.no_load_optim) and not neox_args.deepspeed:
# TODO: this should be configured as a separate arg
return None
if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
print_rank_0(
"WARNING: onebitadam requires the lr scheduler be built by deepspeed - "
"Make sure one is added to your deepspeed config"
)
return None
# Add linear learning rate scheduler.
if neox_args.lr_decay_iters is not None:
num_iters = neox_args.lr_decay_iters
elif neox_args.lr_decay_fraction is not None:
num_iters = math.floor(neox_args.train_iters * neox_args.lr_decay_fraction)
else:
num_iters = neox_args.train_iters
num_iters = max(1, num_iters)
init_step = 0
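    # `warmup` is a fraction of the decay horizon, so this converts it to an iteration count.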
warmup_iter = neox_args.warmup * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=neox_args.lr,
warmup_iter=warmup_iter,
total_iters=num_iters,
decay_style=neox_args.lr_decay_style,
last_iter=init_step,
min_lr=neox_args.min_lr,
use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
override_lr_scheduler=neox_args.override_lr_scheduler,
use_mup=neox_args.use_mup,
)
return lr_scheduler
def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
    """Setup model and optimizer (plus the memory profiler, if enabled)."""
    if neox_args.memory_profiling:
        torch.cuda.memory._record_memory_history(
            True,
            # keep a maximum 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            trace_alloc_record_context=True,
        )
needs_reference_model = (
(neox_args.train_impl == "dpo")
and (neox_args.precompute_model_name is None)
and (not neox_args.dpo_reference_free)
) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None))
model = get_model(neox_args=neox_args, use_cache=use_cache)
if needs_reference_model:
reference_model = get_model(neox_args=neox_args, use_cache=use_cache)
else:
reference_model = None
optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args)
if neox_args.deepspeed and needs_reference_model:
# Need an optimizer & lr_scheduler so make a very small one to keep deepspeed happy...
ref_optimizer, ref_param_groups = get_optimizer(
model=reference_model, neox_args=neox_args, dummy=True
)
ref_lr_scheduler = get_learning_rate_scheduler(
optimizer=ref_optimizer, neox_args=neox_args
)
else:
ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None
if neox_args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
_model_params = param_groups if optimizer is None else None
_lr_scheduler = lr_scheduler
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=neox_args,
lr_scheduler=_lr_scheduler,
dist_init_required=False,
model_parameters=_model_params,
# Need to remove the below so that it doesn't conflict with --deepspeed_config required by autotuning
# config_params=neox_args.deepspeed_config,
mpu=mpu if not neox_args.is_pipe_parallel else None,
)
if needs_reference_model:
reference_model, _, _, _ = deepspeed.initialize(
model=reference_model,
optimizer=ref_optimizer,
args=neox_args,
lr_scheduler=ref_lr_scheduler,
dist_init_required=False,
model_parameters=ref_param_groups,
mpu=mpu if not neox_args.is_pipe_parallel else None,
)
mark_norms_for_sequence_parallel_grad_sync(model, neox_args)
if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks":
# We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE.
model.has_moe_layers = True
model.total_params = get_total_params(model.module)
    print_rank_0(f" > total params: {model.total_params:,}")
if neox_args.is_pipe_parallel:
model.set_has_attention_mask(True)
if neox_args.curriculum_learning:
curr_scheduler = CurriculumScheduler(neox_args.curriculum_learning)
if iteration is not None and iteration > 0:
curr_scheduler.update_difficulty(iteration)
else:
curr_scheduler = None
model.set_batch_fn(
partial(
get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
)
)
else:
model.module.set_batch_fn(
partial(get_batch_sequential, neox_args=neox_args)
)
else:
raise ValueError("Must be using deepspeed to run neox")
if neox_args.load is not None:
neox_args.iteration = load_checkpoint(
neox_args=neox_args,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
iteration=iteration,
)
if needs_reference_model:
_ = load_checkpoint(
neox_args=neox_args,
model=reference_model,
optimizer=ref_optimizer,
lr_scheduler=ref_lr_scheduler,
iteration=iteration,
)
reference_model.eval()
print_rank_0(
f"Loading checkpoint and starting from iteration {neox_args.iteration}"
)
else:
neox_args.iteration = 0
# need this for correct lr scheduling resume from ckpt
# but it will not exist if this is being called for inference
if lr_scheduler is not None:
lr_scheduler.optimizer = model.optimizer
return model, optimizer, lr_scheduler, reference_model
def backward_step(neox_args, timers, optimizer, model, loss):
"""Backward step."""
# Backward pass.
timers("backward-backward").start()
if neox_args.deepspeed:
model.backward(loss)
else:
raise ValueError("Must be using deepspeed to run neox")
timers("backward-backward").stop()
if neox_args.deepspeed:
# DeepSpeed backward propagation already addressed all reduce communication.
# Reset the timer to avoid breaking timer logs below.
timers("backward-allreduce").reset()
else:
raise ValueError("Must be using deepspeed to run neox")
def train_step(
neox_args,
timers,
data_iterator,
model,
optimizer,
lr_scheduler,
reference_model=None,
):
"""Single training step."""
# Pipeline parallelism schedules forward/backward/step
if neox_args.is_pipe_parallel:
reduced_loss = train_step_pipe(
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
)
reduce_metrics = reduced_loss
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
else:
losses = []
metric_dicts = defaultdict(list)
for _ in range(neox_args.gradient_accumulation_steps):
# Forward model for one step.
timers("forward").start()
loss, metric_dict = forward_step(
neox_args=neox_args,
timers=timers,
data_iterator=data_iterator,
model=model,
is_train=True,
reference_model=reference_model,
)
timers("forward").stop()
losses.append(loss)
for key in metric_dict.keys():
metric_dicts[key].append(metric_dict[key])
# Calculate gradients, reduce across processes, and clip.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push(f"Backward pass")
timers("backward").start()
backward_step(
neox_args=neox_args,
timers=timers,
optimizer=optimizer,
model=model,
loss=loss,
)
timers("backward").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
# Update parameters.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push(f"Optimizer step")
timers("optimizer").start()
if neox_args.deepspeed:
model.step()
else:
raise ValueError("Must be using deepspeed to run neox")
timers("optimizer").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
# reduces metrics across machines for logging
reduce_metrics = {
key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys()
}
reduce_metrics["lm_loss"] = reduce_losses(losses).mean()
if neox_args.precision == "fp16" and model.optimizer.overflow:
skipped_iter = 1
else:
skipped_iter = 0
collect_loss_for_unit_test(reduce_metrics["lm_loss"])
return reduce_metrics, skipped_iter
def train_step_pipe(neox_args, timers, model, data_iterator):
"""Single training step with DeepSpeed's pipeline parallel engine."""
assert neox_args.deepspeed
loss = model.train_batch(data_iter=data_iterator)
loss_dict = {"lm_loss": loss}
# Don't break Megatron's timers because we changed code paths.
for t in [
"forward",
"backward",
"allreduce",
"optimizer",
"batch generator",
"data loader",
]:
timers(t).reset()
return loss_dict
def is_save_iter(neox_args, iteration):
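    """Return True if a checkpoint should be saved at this iteration.

    With checkpoint_scale == "linear", saves every checkpoint_factor iterations;
    with "log", saves at iterations 1, f, f^2, ... for f = checkpoint_factor
    (plus any iterations listed in extra_save_iters).
    """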
if neox_args.extra_save_iters and iteration in neox_args.extra_save_iters:
return True
if neox_args.checkpoint_factor:
if neox_args.checkpoint_scale == "linear":
assert float(
neox_args.checkpoint_factor
).is_integer(), "checkpoint_factor must be a whole number when using linear checkpoint_scale"
return iteration % neox_args.checkpoint_factor == 0
elif neox_args.checkpoint_scale == "log":
# Check if iteration is a power of checkpoint_factor
assert neox_args.checkpoint_factor > 1
power = 1
while power < iteration + 1:
if int(power) == iteration:
return True
power *= neox_args.checkpoint_factor
return False
return False
def train(
neox_args,
timers,
model,
reference_model,
optimizer,
lr_scheduler,
train_data_iterator,
valid_data_iterator,
):
"""Train the model function."""
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = neox_args.iteration
timers("interval time").start()
report_memory_flag = True
# get noise scale logger (if neox_args.log_gradient_noise_scale is True)
noise_scale_logger = get_noise_scale_logger(neox_args)
# to monitor if we've skipped many iterations in a row and trigger an early exit
overflow_monitor = OverflowMonitor(optimizer)
if neox_args.profile:
schedule = torch.profiler.schedule(
wait=neox_args.profile_step_start,
warmup=1,
active=neox_args.profile_step_stop - neox_args.profile_step_start,
)
prof = torch.profiler.profile(
schedule=schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
neox_args.tensorboard_dir
),
record_shapes=True,
profile_memory=True,
with_flops=True,
with_modules=True,
with_stack=True,
)
prof.start()
while iteration < neox_args.train_iters:
if neox_args.profile:
prof.step()
if neox_args.profile and iteration == neox_args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
loss_dict, skipped_iter = train_step(
neox_args=neox_args,
timers=timers,
data_iterator=train_data_iterator,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
reference_model=reference_model,
)
if neox_args.profile and iteration == neox_args.profile_step_stop:
torch.cuda.cudart().cudaProfilerStop()
prof.stop()
iteration += 1
neox_args.iteration = iteration
if neox_args.precision == "fp16":
overflow_monitor.check(skipped_iter) # check for repeated overflow
if neox_args.log_gradient_noise_scale: # log noise scale if applicable
noise_scale_logger.update()
# get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you
# may have no tunable parameters on a specific rank
if optimizer.param_groups:
lr = optimizer.param_groups[0].get("lr", 0)
else:
lr = 0
# Logging.
report_memory_flag = training_log(
neox_args=neox_args,
timers=timers,
loss_dict=loss_dict,
total_loss_dict=total_loss_dict,
learning_rate=lr,
iteration=iteration,
loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None,
report_memory_flag=report_memory_flag,
skipped_iter=skipped_iter,
model=model,
optimizer=optimizer,
noise_scale_logger=noise_scale_logger,
)
# Checkpointing
if neox_args.save and is_save_iter(neox_args, iteration):
save_checkpoint(
neox_args=neox_args,
iteration=iteration,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
# Evaluation
if (
neox_args.eval_interval
and iteration % neox_args.eval_interval == 0
and neox_args.do_valid
):
prefix = "iteration {}".format(iteration)
evaluate_and_print_results(
neox_args=neox_args,
prefix=prefix,
forward_step_func=forward_step,
data_iterator=valid_data_iterator,
model=model,
iteration=iteration,
verbose=False,
timers=timers,
reference_model=reference_model,
)
if neox_args.exit_interval and iteration % neox_args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
rank = torch.distributed.get_rank()
print_rank_0(
"rank: {} | time: {} | exiting the program at iteration {}".format(
rank, time_str, iteration
)
)
sys.exit()
return iteration
def evaluate(
neox_args,
forward_step_fn,
data_iterator,
model,
verbose=False,
timers=None,
reference_model=None,
):
"""Evaluation.
neox_args: NeoX Arguments
    forward_step_fn: function with args `neox_args, timers,
        data_iterator & model` that will run a forward pass on the model
data_iterator: Iterator that iterates over batches of data. Should return data in the form:
{'text': np.array([tokens], dtype=np.int64)}
where the size of the array is the model's context size + 1
(`get_batch` transforms it into inputs / labels)
"""
# Turn on evaluation mode which disables dropout.
model.eval()
losses = []
metric_dicts = defaultdict(list)
if neox_args.char_level_ppl:
data_iterator = CharCounter(data_iterator, neox_args.tokenizer)
with torch.no_grad():
iteration = 0
while iteration < neox_args.eval_iters:
iteration += 1
if verbose and iteration % neox_args.log_interval == 0:
print_rank_0(
"Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
)
            # Although we're not accumulating gradients here, we count one iteration as
            # train_batch_size_per_gpu * gradient_accumulation_steps to stay consistent with
            # DeepSpeed's pipe parallel engine; since pipe parallel already accounts for
            # gradient_accumulation_steps, default to a single pass when pipe parallel is on.
for _ in range(
1
if neox_args.is_pipe_parallel
else neox_args.gradient_accumulation_steps
):
# Forward evaluation
loss, metric_dict = forward_step_fn(
model=model,
data_iterator=data_iterator,
neox_args=neox_args,
timers=timers,
reference_model=reference_model,
)
losses.append(loss)
for key in metric_dict.keys():
metric_dicts[key].append(metric_dict[key])
# When contiguous memory optimizations are enabled, the buffers
# allocated by the optimizations are deallocated during backward pass
# in the absence of backward pass the buffers should be reset after each
# forward pass
if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing:
deepspeed.checkpointing.reset()
# reduces losses across processes for logging & run eval harness tasks
eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
for key in metric_dicts.keys():
eval_results[key] = reduce_losses(metric_dicts[key]).mean().item()
eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])
if neox_args.char_level_ppl:
        # calculate character level perplexity, if specified
        # unwrap the data_iterator
tokens_per_char = data_iterator.tokens_per_char()
print_rank_0(f"Counting chars took {data_iterator.total_time} seconds")
data_iterator = data_iterator.data_iterator
eval_results["lm_loss_char_lvl_ppl"] = math.exp(
eval_results["lm_loss"] * tokens_per_char
)
if neox_args.eval_tasks:
from eval_tasks import run_eval_harness
eval_results.update(
run_eval_harness(
model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks
).get("results")
)
# Move model back to the train mode.
model.train()
return eval_results
def collect_loss_for_unit_test(lm_loss):
    # Logic moved to a separate function to allow tracking in unit tests with unittest.mock.patch
pass
def evaluate_and_print_results(
neox_args,
prefix,
forward_step_func,
data_iterator,
model,
iteration,
verbose=False,
timers=None,
chart_name="validation",
reference_model=None,
):
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(
neox_args=neox_args,
forward_step_fn=forward_step_func,
data_iterator=data_iterator,
model=model,
verbose=verbose,
timers=timers,
reference_model=reference_model,
)
string = f" {chart_name} results at {prefix} | "
for k, v in total_loss_dict.items():
if isinstance(v, dict):
if neox_args.eval_tasks and "results" in v:
v = v["results"]
print(v)
for k2, v2 in v.items():
k3 = "_".join([k, k2])
string += f"{k3} value: {v2:.6E} | "
tb_wandb_log(
f"{chart_name}/{k3}",
v2,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
else:
string += f"{k} value: {v:.6E} | "
tb_wandb_log(
f"{chart_name}/{k}",
v,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
length = len(string) + 1
print_rank_0("-" * length)
print_rank_0(string)
print_rank_0("-" * length)
def save_snapshot(neox_args):
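    """Dump a pickled CUDA memory snapshot to memory_profiling_path; the pickle can be
    inspected with PyTorch's memory visualization tooling."""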
assert (
neox_args.memory_profiling_path is not None
), "Must pass memory_profiling_path config arg to use profiling"
snapshot = torch.cuda.memory._snapshot()
    snapshot_path = neox_args.memory_profiling_path
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f:
dump(snapshot, f)