# megatron/logging.py
# Copyright (c) 2024, EleutherAI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import torch
try:
import wandb
except ModuleNotFoundError:
pass
from megatron import mpu, print_rank_0
from megatron.utils import report_memory
import math


class Tee:
"""Duplicate output to both stdout/err and file"""
def __init__(self, file, err: bool = False) -> None:
self.file = open(file, "w")
self.err = err
if not err:
self.std = sys.stdout
sys.stdout = self
else:
self.std = sys.stderr
            sys.stderr = self

    def __del__(self) -> None:
if not self.err:
sys.stdout = self.std
else:
sys.stderr = self.std
        self.file.close()

    def write(self, data) -> None:
try:
self.file.write(data)
except OSError:
pass
try:
self.std.write(data)
except OSError:
            pass

    def flush(self) -> None:
try:
self.file.flush()
except OSError:
            pass


def human_readable_flops(num) -> str:
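    # e.g. human_readable_flops(1.5e14) returns "150.0TFLOPS"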
for unit in [
"",
"KFLOPS",
"MFLOPS",
"GFLOPS",
"TFLOPS",
"PFLOPS",
"EFLOPS",
"ZFLOPS",
]:
if abs(num) < 1000.0:
return "%3.1f%s" % (num, unit)
num /= 1000.0
return "%.1f%s" % (num, "Yi")
def get_flops(neox_args, iter_time_s) -> float:
"""
Use FLOPS calculation from Megatron-DeepSpeed:
https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
They get it from https://arxiv.org/pdf/2104.04473.pdf
"""
world_size = torch.distributed.get_world_size()
vocab_size = neox_args.padded_vocab_size
batch_size = neox_args.train_batch_size
seq_len = neox_args.seq_length
hidden_size = neox_args.hidden_size
num_layers = neox_args.num_layers
ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
if "rwkv" in neox_args.attention_config:
num_heads = neox_args.num_attention_heads
flops_per_iteration = (
batch_size
* seq_len
* (
78 * hidden_size * hidden_size * num_layers
+ 84 * hidden_size * num_layers
+ 16 * hidden_size
+ 12 * hidden_size * vocab_size
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
elif "mamba" in neox_args.attention_config:
# from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
if neox_args.expansion_factor:
d_inner = neox_args.hidden_size * neox_args.expansion_factor
elif neox_args.intermediate_size:
d_inner = neox_args.intermediate_size
else:
d_inner = neox_args.hidden_size * 2 # default expansion factor
d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
dt_rank = math.ceil(neox_args.hidden_size / 16)
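        # Per-layer Mamba FLOPs are split into the selective-scan (SSM) term,
        # the input/output projections, and the causal conv, each scaled by the
        # same activation-checkpointing factor as the dense path; embedding
        # FLOPs are added once at the end.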
ssm_flops = (
ckpt_activations_factor
* d_inner
* seq_len
* batch_size
* (11 * d_state + 4 * dt_rank + 1)
)
mamba_projectors_flops = (
ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size
)
mamba_conv_flops = (
ckpt_activations_factor
* seq_len
* batch_size
* 2
* d_inner
* conv_dimension
)
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
else:
flops_per_iteration = (
24
* ckpt_activations_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
)
)
    return flops_per_iteration / (iter_time_s * world_size)


def training_log(
neox_args,
timers,
loss_dict,
total_loss_dict,
learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
model,
optimizer,
noise_scale_logger,
):
"""Log training information such as losses, timing, etc."""
# Update losses.
skipped_iters_key = "skipped iterations"
total_loss_dict[skipped_iters_key] = (
total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
)
got_nan_key = "got nan"
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(key, 0.0) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
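            # +/-inf catches overflowed losses; `value != value` is the usual
            # NaN test, since NaN never compares equal to itself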
is_nan = value == float("inf") or value == -float("inf") or value != value
got_nan = got_nan or is_nan
total_loss_dict[got_nan_key] = total_loss_dict.get(got_nan_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
def add_to_logging(name):
if name in timers.timers:
timers_to_log.append(name)
if not neox_args.is_pipe_parallel:
add_to_logging("forward")
add_to_logging("backward")
add_to_logging("backward-backward")
add_to_logging("backward-allreduce")
add_to_logging("backward-master-grad")
add_to_logging("backward-clip-grad")
add_to_logging("optimizer")
add_to_logging("batch generator")
# Log timer info to tensorboard and wandb
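    # On a logging step `iteration % log_interval` is 0, so fall back to the
    # full interval length when normalizing timer averages.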
normalizer = iteration % neox_args.log_interval
if normalizer == 0:
normalizer = neox_args.log_interval
if torch.distributed.get_rank() == 0:
timers.write(
names=timers_to_log, iteration=iteration, normalizer=normalizer
)
else:
        # With pipeline parallelism, the Megatron timers are overridden by the
        # DeepSpeed ones. Try to grab timer values from the model engine; these
        # were only recently added to DeeperSpeed, so check that the engine has
        # the attribute first.
if hasattr(model, "timer_values") and model.timer_values is not None:
if (
model.wall_clock_breakdown()
and model.global_steps % model.steps_per_print() == 0
):
timer_values = model.timer_values
# deepspeed already logs to tensorboard / prints values, so just log to wandb
if neox_args.use_wandb and torch.distributed.get_rank() == 0:
for key in timer_values:
tb_wandb_log(
f"timers/{key}",
timer_values[key],
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# write losses, lr, etc. every step
tb_wandb_log(
"train/learning_rate",
learning_rate,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
for key in loss_dict:
tb_wandb_log(
f'train/{key.replace(" ", "_")}',
loss_dict[key],
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
if neox_args.fp16:
tb_wandb_log(
f"train/loss_scale",
loss_scale,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# log gradient noise scale
if neox_args.log_gradient_noise_scale:
if noise_scale_logger.noise_scale is not None:
tb_wandb_log(
f"train/noise_scale",
noise_scale_logger.noise_scale,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# (optional) Log optimizer states to wandb / tb every step
if neox_args.log_optimizer_states:
for k, v in optimizer.state_dict()["optimizer_state_dict"]["state"].items():
for ki, vi in v.items(): # step, module
if ki != "step":
opt_state_norm = torch.norm(vi) if hasattr(vi, "dim") else vi
tb_wandb_log(
f"optimizer_state_norms/{k}_{ki}",
opt_state_norm,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# (optional) Log grad/param norms to wandb / tb every step
if (
neox_args.log_grad_pct_zeros
or neox_args.log_grad_norm
or neox_args.log_param_norm
):
if neox_args.log_grad_pct_zeros or neox_args.log_grad_norm:
model.store_gradients = True # start storing gradients
for i, (name, param) in enumerate(model.module.named_parameters()):
if neox_args.log_grad_pct_zeros:
if (
hasattr(model, "stored_gradients")
and model.stored_gradients is not None
):
grad = model.stored_gradients[i]
if grad is not None:
tb_wandb_log(
f"pct_grad_zeros/{name}",
(grad == 0).float().mean().item() * 100,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
all_ranks=True,
)
if neox_args.log_grad_norm:
if (
hasattr(model, "stored_gradients")
and model.stored_gradients is not None
):
grad = model.stored_gradients[i]
if grad is not None:
tb_wandb_log(
f"gradient_norms/{name}",
torch.norm(grad),
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
all_ranks=True,
)
if neox_args.log_param_norm:
tb_wandb_log(
f"parameter_norms/{name}",
torch.norm(param),
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
all_ranks=True,
)
if iteration % neox_args.log_interval == 0:
# log other stuff every neox_args.log_interval iters
elapsed_time = timers("interval time").elapsed()
iteration_time = elapsed_time / neox_args.log_interval
samples_per_sec = neox_args.train_batch_size / iteration_time
log_string = " samples/sec: {:.3f} |".format(samples_per_sec)
tb_wandb_log(
"runtime/samples_per_sec",
samples_per_sec,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
tb_wandb_log(
"runtime/iteration_time",
iteration_time,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
log_string += " iteration {:8d}/{:8d} |".format(
iteration, neox_args.train_iters
)
log_string += " elapsed time per iteration (ms): {:.1f} |".format(
elapsed_time * 1000.0 / neox_args.log_interval
)
log_string += " learning rate: {:.3E} |".format(learning_rate)
num_iterations = max(
1, neox_args.log_interval - total_loss_dict[skipped_iters_key]
)
# log curriculum learning
if neox_args.curriculum_learning:
tb_wandb_log(
"curriculum_seqlen",
neox_args.curriculum_seqlen,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
# log tflop / gpu
flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
log_string += (
f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
)
tb_wandb_log(
"runtime/flops_per_sec_per_gpu",
flops_per_s_per_gpu,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
for key in total_loss_dict:
if key not in [skipped_iters_key, got_nan_key]:
v = (
total_loss_dict[key].item()
if hasattr(total_loss_dict[key], "item")
else total_loss_dict[key]
)
avg = v / float(num_iterations)
log_string += " {}: {:.6E} |".format(key, avg)
total_loss_dict[key] = 0.0
if neox_args.precision == "fp16":
log_string += " loss scale: {:.1f} |".format(loss_scale)
log_string += " number of skipped iterations: {:3d} |".format(
total_loss_dict[skipped_iters_key]
)
log_string += " number of nan iterations: {:3d} |".format(
total_loss_dict[got_nan_key]
)
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[got_nan_key] = 0
print_rank_0(log_string)
if report_memory_flag:
report_memory("after {} iterations".format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=neox_args.log_interval)
    return report_memory_flag


def tb_wandb_log(
key: str,
value: float,
iteration_no: int,
use_wandb: bool,
tensorboard_writer=None,
comet_experiment=None,
all_ranks: bool = False,
):
    # Logs the scalar to TensorBoard, Weights & Biases, and Comet (whichever
    # are configured), from rank 0 only unless `all_ranks` is set.
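    # Usage sketch (illustrative values; `writer` is a SummaryWriter):
    #   tb_wandb_log("train/lm_loss", 2.31, iteration_no=100,
    #                use_wandb=True, tensorboard_writer=writer)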
do_log = torch.distributed.get_rank() == 0 or all_ranks
if do_log and value is not None:
if tensorboard_writer:
tensorboard_writer.add_scalar(key, value, iteration_no)
if use_wandb:
wandb.log({key: value}, step=iteration_no)
if comet_experiment:
comet_experiment.__internal_api__log_metric__(
key, value, framework="gpt-neox", step=iteration_no
)