# Andromeda / train.py
# Source: Hugging Face Hub file view (uploaded by kye, commit ca4fc4d "Upload 73 files", 24.7 kB).
import math
import multiprocessing
import os
from datetime import timedelta
from functools import partial
from itertools import chain
import torch
# import bitsandbytes as bnb
from torch.distributed.fsdp import (
FullyShardedDataParallel,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
)
from accelerate import Accelerator
from accelerate.utils import (DummyOptim, InitProcessGroupKwargs)
from accelerate.logging import get_logger
from datasets import load_dataset
from lion_pytorch import Lion
from torch.nn import LayerNorm
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy
)
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoTokenizer, default_data_collator,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup, set_seed)
from Andromeda.utils.stable_adamw import StableAdamWUnfused
from Andromeda.core.transformer import Transformer, AndromedaEmbedding
# from Andromeda.model import Andromeda
from Andromeda.model import AndromedaEmbedding #, Andromeda
from Andromeda.configs import Andromeda1Billion
########### SETUP CONFIG
import torch.distributed as dist
from accelerate.state import AcceleratorState
# Module-level logger obtained from Accelerate's get_logger at INFO level;
# Train() uses it for periodic loss reports.
logger = get_logger(name=__name__, log_level="INFO")
class CFG:
    """Hyper-parameters and run settings read by Train() and the dataset builders."""

    # Per-process micro-batch size: used as the DataLoader batch size and as
    # DeepSpeed's train_micro_batch_size_per_gpu.
    BATCH_SIZE: int = 1
    # Micro-batches accumulated per optimizer step.
    GRADIENT_ACCUMULATE_EVERY: int = 1
    SEED: int = 42
    # 1e-4 is the value noted for Lion (3e-4 was used previously).
    LEARNING_RATE: float = 1e-4
    WEIGHT_DECAY: float = 0.1
    # Chunk length, in tokens, produced by build_dataloaders().
    SEQ_LEN: int = 8192
    # Worker processes for datasets.map(); evaluated once at import time.
    NUM_CPU: int = multiprocessing.cpu_count()
    # NOTE(review): not read by any code visible in this file — confirm it is still needed.
    USE_DEEPSPEED: bool = True
    USE_FSDP: bool = True
    USE_PRETOKENIZED: bool = True
    USE_ACTIVATION_CHECKPOINTING: bool = True
    # Path of a "step_<n>" checkpoint directory to resume from.
    # An empty string (falsy) starts a fresh run.
    RESUME_FROM_CHECKPOINT: str = ""
    # Save the full training state every this many optimizer steps.
    CHECKPOINTING_STEPS: int = 1000
    OUTPUT_DIR: str = 'checkpoints/'  # Folder
    # NOTE(review): W&B entity name; not currently passed to the tracker.
    ENTITY_NAME: str = "Andromeda"
    # Log the loss every this many steps; 0 disables the periodic log line.
    LOGGING_STEPS: int = 100
# helpers
def print_num_params(model, accelerator: Accelerator):
    """
    Report the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted. The total is
    emitted through ``accelerator.print``.
    """
    trainable_counts = (
        weight.numel() for weight in model.parameters() if weight.requires_grad
    )
    total = sum(trainable_counts)
    accelerator.print(f"Number of parameters in model: {total}")
# activation checkpointing
def activation_checkpointing(
    model: torch.nn.Module,
    offload_to_cpu: bool = False,
    accelerator: Accelerator = None,
):
    """
    Wrap every ``Transformer`` submodule of ``model`` with activation checkpointing.

    The wrapping is done in place via ``apply_activation_checkpointing`` using a
    non-reentrant ``checkpoint_wrapper``.

    Args:
        model (Module): Model whose ``Transformer`` submodules get wrapped.
        offload_to_cpu (bool, optional): Forwarded to ``checkpoint_wrapper``;
            offloads saved activations to CPU when True. Defaults to False.
        accelerator (Accelerator, optional): If given, used to print a notice.
            Defaults to None.
    """
    if accelerator is not None:
        accelerator.print("Using activation checkpointing")

    wrapper_fn = partial(
        checkpoint_wrapper,
        offload_to_cpu=offload_to_cpu,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper_fn,
        check_fn=lambda submodule: isinstance(submodule, Transformer),
    )
# FSDP
def fsdp(
    model: torch.nn.Module,
    auto_wrap: bool = False,
    mp: str = "fp32",
    shard_strat: str = "NO_SHARD",
):
    """
    Wrap ``model`` in FullyShardedDataParallel (FSDP).

    Args:
        model (torch.nn.Module): The module to wrap.
        auto_wrap (bool, optional): If True, each ``Transformer`` submodule is
            wrapped as its own FSDP unit via ``transformer_auto_wrap_policy``;
            otherwise only the root module is wrapped. Default is False.
        mp (str, optional): Precision used for parameters, gradient reduction
            and buffers alike: 'bf16', 'fp16' or 'fp32'. Default is 'fp32'.
        shard_strat (str, optional): 'SHARD_GRAD' (ShardingStrategy.SHARD_GRAD_OP),
            'FULL_SHARD' or 'NO_SHARD'. Default is 'NO_SHARD'.

    Returns:
        FullyShardedDataParallel: The wrapped model. It uses BACKWARD_PRE backward
        prefetch, forward prefetch and ``use_orig_params=True``.

    Raises:
        ValueError: If ``mp`` is not 'bf16', 'fp16' or 'fp32'.
        ValueError: If ``shard_strat`` is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.
    """
    precision_dtypes = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    sharding_strategies = {
        "SHARD_GRAD": ShardingStrategy.SHARD_GRAD_OP,
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
    }

    # Validate both options before touching the model.
    if mp not in precision_dtypes:
        raise ValueError(
            "Invalid mp. Expected 'bf16', 'fp16' or 'fp32', got: {}".format(mp)
        )
    if shard_strat not in sharding_strategies:
        raise ValueError(
            "Invalid shard_strat. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
                shard_strat
            )
        )

    auto_wrap_policy = None
    if auto_wrap:
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={Transformer},
        )

    dtype = precision_dtypes[mp]
    # One dtype for parameters, gradient communication and buffers.
    mixed_precision = MixedPrecision(
        param_dtype=dtype,
        reduce_dtype=dtype,
        buffer_dtype=dtype,
    )

    return FullyShardedDataParallel(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=sharding_strategies[shard_strat],
        forward_prefetch=True,
        use_orig_params=True,
    )
# learning rate scheduler
def get_lr_scheduler_with_warmup(
    optimizer: torch.optim.Optimizer,
    scheduler_type: str,
    num_warmup_steps: int,
    max_train_steps: int,
    grad_accumulate_every: int = 1,
    accelerator: Accelerator = None,
):
    """
    Create a linear or cosine learning-rate schedule with warmup.

    Both the warmup length and the total length are multiplied by
    ``grad_accumulate_every`` before being handed to the transformers schedule
    factory.

    Args:
        optimizer (Optimizer): Optimizer whose learning rate is scheduled.
        scheduler_type (str): "linear" or "cosine".
        num_warmup_steps (int): Warmup length before scaling.
        max_train_steps (int): Total schedule length before scaling.
        grad_accumulate_every (int, optional): Scale factor for both lengths. Defaults to 1.
        accelerator (Accelerator, optional): If given, used to print which schedule
            is chosen. Defaults to None.

    Returns:
        The scheduler built by ``get_linear_schedule_with_warmup`` or
        ``get_cosine_schedule_with_warmup``.

    Raises:
        ValueError: If scheduler_type is neither "linear" nor "cosine".
    """
    if accelerator is not None:
        accelerator.print(f"Using {scheduler_type} lr scheduler")

    factories = {
        "linear": get_linear_schedule_with_warmup,
        "cosine": get_cosine_schedule_with_warmup,
    }
    build_schedule = factories.get(scheduler_type)
    if build_schedule is None:
        raise ValueError(
            "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
                scheduler_type
            )
        )

    scaled_warmup = num_warmup_steps * grad_accumulate_every
    scaled_total = max_train_steps * grad_accumulate_every
    return build_schedule(
        optimizer=optimizer,
        num_warmup_steps=scaled_warmup,
        num_training_steps=scaled_total,
    )
# optimizers
def decoupled_optimizer(
    model: torch.nn.Module,
    learning_rate: float,
    weight_decay: float,
    beta_1: float,
    beta_2: float,
    optimizer_type: str,
    use_fsdp: bool = True,
    accelerator: "Accelerator" = None,
):
    """
    Build an optimizer with a weight-decay group and a no-weight-decay group.

    Every entry of ``model.named_parameters()`` lands in exactly one group, so no
    parameter is silently left out of the optimizer:

    * decay group (``weight_decay``): the ``weight`` of every ``torch.nn.Linear``
      module, except the output projection ``to_logits.weight``
      (``_fsdp_wrapped_module.to_logits.weight`` when ``use_fsdp`` is True);
    * no-decay group (``weight_decay=0.0``): everything else — biases, LayerNorm
      parameters, embedding tables and the excluded output projection.

    ``named_parameters()`` de-duplicates shared tensors, so a weight tied between
    modules appears once. This matters because an optimizer rejects a tensor
    that appears in two groups.

    Args:
        model (Module): The model whose parameters are optimized.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): Weight decay applied to the decay group.
        beta_1 (float): The exponential decay rate for the 1st moment estimates.
        beta_2 (float): The exponential decay rate for the 2nd moment estimates.
        optimizer_type (str): 'lion', 'adamw', 'deepspeed' (Accelerate DummyOptim)
            or 'stable_adamw'.
        use_fsdp (bool, optional): Selects the FSDP-prefixed name of the excluded
            ``to_logits.weight``. Defaults to True.
        accelerator (Accelerator, optional): If given, used to print which optimizer
            is chosen. Defaults to None.

    Returns:
        Optimizer: The initialized optimizer; param_groups[0] is the decay group,
        param_groups[1] the no-decay group.

    Raises:
        ValueError: If optimizer_type is not 'lion', 'adamw', 'deepspeed' or 'stable_adamw'.
    """
    if accelerator is not None:
        accelerator.print(f"Using {optimizer_type} optimizer")

    # A root-level FSDP wrapper prefixes module names with "_fsdp_wrapped_module.".
    if use_fsdp:
        exclude_param = "_fsdp_wrapped_module.to_logits.weight"
    else:
        exclude_param = "to_logits.weight"

    # named_modules() and named_parameters() traverse the tree in the same order,
    # so the owner of "a.b.weight" is the module registered as "a.b".
    modules_by_name = dict(model.named_modules())

    decay_param = []
    no_decay_param = []
    for param_name, param in model.named_parameters():
        owner_name, _, attr_name = param_name.rpartition(".")
        owner = modules_by_name.get(owner_name)
        is_linear_weight = isinstance(owner, torch.nn.Linear) and attr_name == "weight"
        if is_linear_weight and param_name != exclude_param:
            decay_param.append(param)
        else:
            no_decay_param.append(param)

    grouped_params = [
        {"params": decay_param, "weight_decay": weight_decay},
        {"params": no_decay_param, "weight_decay": 0.0},
    ]

    if optimizer_type == "lion":
        optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "adamw":
        optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "deepspeed":
        optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "stable_adamw":
        optimizer = StableAdamWUnfused(
            grouped_params, lr=learning_rate, betas=(beta_1, beta_2),
        )
    else:
        raise ValueError(
            "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
                optimizer_type
            )
        )

    return optimizer
# dataloaders
def build_dataloaders():
    """
    Build the tokenized, fixed-length training dataset from raw text.

    Despite the name, this returns a dataset, not a DataLoader. Steps:

    1. Load the "EleutherAI/gpt-neox-20b" tokenizer.
    2. Load the "openwebtext" train split.
    3. Tokenize every text with the tokenizer's EOS token appended.
    4. Concatenate each mapped batch and cut it into chunks of exactly
       ``CFG.SEQ_LEN`` tokens. The tail of each batch that does not fill a
       whole chunk is dropped.

    Returns:
        Dataset: The processed dataset ready for training.
    """
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    dataset = load_dataset("openwebtext", split="train")

    tokenized_dataset = dataset.map(
        lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
        batched=True,
        num_proc=CFG.NUM_CPU,
        remove_columns=["text"],
    )

    block_size = CFG.SEQ_LEN

    def group_texts(examples):
        """Concatenate a batch of tokenized texts and split it into block_size chunks."""
        concatenated_examples = {
            k: list(chain.from_iterable(examples[k])) for k in examples.keys()
        }
        total_length = len(concatenated_examples[next(iter(examples))])
        # Always drop the remainder, even when the batch is shorter than one
        # block: every emitted row must be exactly block_size tokens long, or
        # default_data_collator cannot stack rows and the model would see
        # sequences of unexpected length. A batch shorter than block_size thus
        # contributes no rows. (Padding would be the alternative if the model
        # supported it.)
        total_length = (total_length // block_size) * block_size
        return {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }

    train_dataset = tokenized_dataset.map(
        group_texts, batched=True, num_proc=CFG.NUM_CPU,
    )

    return train_dataset
# TODO: switch to the Falcon web dataset.
def build_pre_tokenized(
    dataset_name: str = "conceptofmind/c4_0-to-20_neox_with_eos_8k",
    split: str = "train[:10]",
):
    """
    Load an already-tokenized training dataset from the Hugging Face Hub.

    Train() reads the ``input_ids`` column of each batch, so the dataset must
    provide one. Judging by its name, the default dataset holds NeoX token ids
    with EOS in 8k-token rows (assumption — TODO confirm against the dataset card).

    Args:
        dataset_name: Hub dataset id. Defaults to the first C4 shard
            (``c4_0-to-20``). Shards named ``c4_21-to-40`` through
            ``c4_81-to-100`` under the same prefix can be loaded the same way
            and combined with ``datasets.concatenate_datasets``
            (TODO confirm availability).
        split: Split expression passed to ``load_dataset``. NOTE: the default
            ``"train[:10]"`` loads only the first 10 rows, which looks like a
            debugging setting; pass ``"train"`` for a full run.

    Returns:
        The loaded dataset.
    """
    return load_dataset(dataset_name, split=split)
def Train():
    """
    Build the model, data, optimizer and scheduler, then run the training loop.

    Settings come from the module-level ``CFG``. The run is tracked through
    Accelerate with ``log_with="wandb"``. The full training state is saved every
    ``CFG.CHECKPOINTING_STEPS`` optimizer steps to ``CFG.OUTPUT_DIR/step_<n>``.
    Training can be resumed from such a directory via
    ``CFG.RESUME_FROM_CHECKPOINT``. The final weights are written to
    ``CFG.OUTPUT_DIR/final/final_model.pt``.
    """
    # ---- accelerator -------------------------------------------------------
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="fp16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )

    # DeepSpeed reads the per-GPU micro-batch size from its config. The plugin is
    # None when Accelerate is not configured for DeepSpeed, so guard the access.
    state = AcceleratorState()
    deepspeed_plugin = getattr(state, "deepspeed_plugin", None)
    if deepspeed_plugin is not None:
        deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = CFG.BATCH_SIZE

    accelerator.init_trackers(
        project_name="Andromeda",
        config={
            "batch_size": CFG.BATCH_SIZE,
            "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
            "learning_rate": CFG.LEARNING_RATE,
            "seq_len": CFG.SEQ_LEN,
        },
    )

    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    set_seed(CFG.SEED)

    # ---- model -------------------------------------------------------------
    model = Andromeda1Billion()
    print_num_params(model, accelerator)

    if CFG.USE_FSDP:
        model = fsdp(model, mp="fp16", shard_strat="SHARD_GRAD")

    if CFG.USE_ACTIVATION_CHECKPOINTING:
        # The second positional parameter of activation_checkpointing() is
        # offload_to_cpu, so the accelerator must be passed by keyword.
        activation_checkpointing(model, accelerator=accelerator)

    model = accelerator.prepare(model)

    # ---- data --------------------------------------------------------------
    if CFG.USE_PRETOKENIZED:
        train_dataset = build_pre_tokenized()
    else:
        train_dataset = build_dataloaders()

    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
    )

    # ---- optimizer and lr schedule -----------------------------------------
    optim = decoupled_optimizer(
        model=model,
        learning_rate=CFG.LEARNING_RATE,
        weight_decay=CFG.WEIGHT_DECAY,
        beta_1=0.90,
        beta_2=0.95,
        optimizer_type="lion",
        # Parameter names carry the FSDP prefix only when fsdp() wrapped the model.
        use_fsdp=CFG.USE_FSDP,
        accelerator=accelerator,
    )

    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    num_warmup_steps = int(max_train_steps * 0.01)
    accelerator.print(f"Num warmup steps: {num_warmup_steps}")

    lr_scheduler = get_lr_scheduler_with_warmup(
        optimizer=optim,
        scheduler_type="cosine",
        num_warmup_steps=num_warmup_steps,
        max_train_steps=max_train_steps,
        grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
    )

    optim, train_loader, lr_scheduler = accelerator.prepare(
        optim, train_loader, lr_scheduler
    )

    accelerator.register_for_checkpointing(lr_scheduler)

    # prepare() may shard the dataloader across processes, which changes its
    # length, so the step budget is recomputed from the prepared loader.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")

    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    # Counts optimizer steps (one per gradient sync), not micro-batches.
    completed_steps = 0

    # ---- resume ------------------------------------------------------------
    if CFG.RESUME_FROM_CHECKPOINT:
        accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
        accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
        # Checkpoints are saved as ".../step_<completed_steps>". normpath strips a
        # trailing separator so basename() does not return an empty string.
        path = os.path.basename(os.path.normpath(CFG.RESUME_FROM_CHECKPOINT))
        training_difference = os.path.splitext(path)[0]
        resumed_steps = int(training_difference.replace("step_", ""))
        # The loader yields micro-batches: GRADIENT_ACCUMULATE_EVERY per optimizer step.
        train_loader = accelerator.skip_first_batches(
            train_loader, resumed_steps * CFG.GRADIENT_ACCUMULATE_EVERY
        )
        # completed_steps and the progress bar count optimizer steps.
        completed_steps += resumed_steps
        progress_bar.update(resumed_steps)

    # ---- training ----------------------------------------------------------
    model.train()
    for batch in train_loader:
        with accelerator.accumulate(model):
            inputs = batch["input_ids"].to(accelerator.device)
            loss = model(inputs, return_loss=True)
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optim.step()
            lr_scheduler.step()
            optim.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            # Log against completed_steps so the tracker's step axis keeps
            # increasing after a resume (the loader index restarts at 0).
            accelerator.log({"loss": loss.item()}, step=completed_steps)

            # Checkpoint only on real optimizer steps; otherwise every micro-batch
            # of an accumulation window (including before the first step) would
            # re-save the same directory.
            if (
                isinstance(CFG.CHECKPOINTING_STEPS, int)
                and CFG.CHECKPOINTING_STEPS > 0
                and completed_steps % CFG.CHECKPOINTING_STEPS == 0
            ):
                output_dir = f"step_{completed_steps}"
                if CFG.OUTPUT_DIR is not None:
                    output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                accelerator.save_state(output_dir)

            if CFG.LOGGING_STEPS > 0 and completed_steps % CFG.LOGGING_STEPS == 0:
                logger.info(
                    f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
                )

        if completed_steps >= max_train_steps:
            break

    # ---- end training ------------------------------------------------------
    accelerator.end_training()

    # ---- save final model --------------------------------------------------
    if CFG.OUTPUT_DIR is not None:
        accelerator.wait_for_everyone()
        # Gather the state dict on every process: for sharded models this can be
        # a collective operation. Running it on the main process alone (e.g.
        # inside main_process_first()) can hang the job.
        state_dict = accelerator.get_state_dict(model)
        final_dir = os.path.join(CFG.OUTPUT_DIR, "final")
        if accelerator.is_local_main_process:
            # torch.save does not create missing directories.
            os.makedirs(final_dir, exist_ok=True)
        accelerator.save(state_dict, os.path.join(final_dir, "final_model.pt"))
def main():
    """
    Initialise the NCCL process group from environment variables, then run Train().

    The process group is created with the default ``env://`` rendezvous, which
    reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE. When scaling to
    multiple GPUs or nodes, these are supplied by the launcher configuration
    (see ``accelerate config``).

    Raises:
        RuntimeError: If any of MASTER_ADDR, MASTER_PORT, RANK or WORLD_SIZE is
            not set. The message lists every missing variable.
    """
    required_env = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    missing = [name for name in required_env if name not in os.environ]
    if missing:
        raise RuntimeError(
            "Missing required distributed environment variable(s): "
            + ", ".join(missing)
        )

    if not dist.is_initialized():
        # Use the same long timeout that Train() requests from Accelerate:
        # Accelerate is expected to reuse this already-initialised group rather
        # than create its own (TODO confirm for the pinned accelerate version).
        dist.init_process_group(
            backend="nccl", timeout=timedelta(seconds=1_000_000)
        )
    Train()
# Start training only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()