import math
import multiprocessing
import os
from datetime import timedelta
from functools import partial
from itertools import chain

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn import LayerNorm
from torch.optim import AdamW
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import DummyOptim, InitProcessGroupKwargs
from datasets import load_dataset
from lion_pytorch import Lion
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    default_data_collator,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    set_seed,
)

from Andromeda.configs import Andromeda1Billion
from Andromeda.core.transformer import Transformer, AndromedaEmbedding
from Andromeda.utils.stable_adamw import StableAdamWUnfused

logger = get_logger(__name__, log_level="INFO")


class CFG:
    """Training hyperparameters and runtime flags."""

    BATCH_SIZE: int = 1
    GRADIENT_ACCUMULATE_EVERY: int = 1
    SEED: int = 42
    LEARNING_RATE: float = 1e-4
    WEIGHT_DECAY: float = 0.1
    SEQ_LEN: int = 8192
    NUM_CPU: int = multiprocessing.cpu_count()
    USE_DEEPSPEED: bool = True
    USE_FSDP: bool = True
    USE_PRETOKENIZED: bool = True
    USE_ACTIVATION_CHECKPOINTING: bool = True
    RESUME_FROM_CHECKPOINT: str = None  # path to a checkpoint directory, or None
    CHECKPOINTING_STEPS: int = 1000
    OUTPUT_DIR: str = 'checkpoints/'
    ENTITY_NAME: str = "Andromeda"
    LOGGING_STEPS: int = 100


def print_num_params(model, accelerator: Accelerator):
    """Print the number of trainable parameters in the model."""
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print(f"Number of parameters in model: {n_params}")


def activation_checkpointing(
    model: torch.nn.Module,
    offload_to_cpu: bool = False,
    accelerator: Accelerator = None,
):
    """
    Apply activation checkpointing to a model.

    Args:
        model (Module): The model to which to apply activation checkpointing.
        offload_to_cpu (bool, optional): Whether to offload the activations to CPU. Defaults to False.
        accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
    """
    if accelerator is not None:
        accelerator.print("Using activation checkpointing")

    def check_fn(submodule):
        return isinstance(submodule, Transformer)

    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        offload_to_cpu=offload_to_cpu,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )
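
# Illustrative usage (a sketch; assumes `model` contains Transformer submodules,
# which is what check_fn above looks for):
#   activation_checkpointing(model, offload_to_cpu=False, accelerator=accelerator)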


def fsdp(
    model: torch.nn.Module,
    auto_wrap: bool = False,
    mp: str = "fp32",
    shard_strat: str = "NO_SHARD",
):
    """
    Wrap a PyTorch model with FullyShardedDataParallel (FSDP) to enable efficient data parallelism and model sharding.

    Args:
        model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP.
        auto_wrap (bool, optional): If True, enables automatic wrapping of the model's layers according to the transformer_auto_wrap_policy. Default is False.
        mp (str, optional): The mixed precision mode. Can be 'bf16' for BFloat16, 'fp16' for Float16, or 'fp32' for Float32 precision. Default is 'fp32'.
        shard_strat (str, optional): The sharding strategy. Can be 'SHARD_GRAD' to shard gradients and optimizer state, 'FULL_SHARD' for full model sharding, or 'NO_SHARD' for no sharding. Default is 'NO_SHARD'.

    Raises:
        ValueError: If mp is not 'bf16', 'fp16' or 'fp32'.
        ValueError: If shard_strat is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.

    Returns:
        torch.nn.Module: The input model wrapped with FSDP.
    """
    if auto_wrap:
        Andromeda_auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Transformer,
            },
        )
    else:
        Andromeda_auto_wrap_policy = None

    if mp == "bf16":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    elif mp == "fp16":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
    elif mp == "fp32":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.float32,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        )
    else:
        raise ValueError(
            "Invalid mixed precision mode. Expected 'bf16', 'fp16' or 'fp32', got: {}".format(
                mp
            )
        )

    if shard_strat == "SHARD_GRAD":
        sharding_strat_fsdp = ShardingStrategy.SHARD_GRAD_OP
    elif shard_strat == "FULL_SHARD":
        sharding_strat_fsdp = ShardingStrategy.FULL_SHARD
    elif shard_strat == "NO_SHARD":
        sharding_strat_fsdp = ShardingStrategy.NO_SHARD
    else:
        raise ValueError(
            "Invalid sharding strategy. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
                shard_strat
            )
        )

    model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=Andromeda_auto_wrap_policy,
        mixed_precision=mp_fsdp,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=sharding_strat_fsdp,
        forward_prefetch=True,
        use_orig_params=True,
    )

    return model
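
# Illustrative usage (a sketch; the dtype and strategy here are arbitrary choices,
# not the ones used by Train() below, which calls fsdp(model, mp="fp16", shard_strat="SHARD_GRAD")):
#   model = fsdp(model, auto_wrap=True, mp="bf16", shard_strat="FULL_SHARD")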


def get_lr_scheduler_with_warmup(
    optimizer: torch.optim.Optimizer,
    scheduler_type: str,
    num_warmup_steps: int,
    max_train_steps: int,
    grad_accumulate_every: int = 1,
    accelerator: Accelerator = None,
):
    """
    Get a learning rate scheduler with warmup.

    Args:
        optimizer (Optimizer): The optimizer for which to create the learning rate scheduler.
        scheduler_type (str): The type of learning rate scheduler to create, either "linear" or "cosine".
        num_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
        max_train_steps (int): The maximum number of training steps.
        grad_accumulate_every (int, optional): The gradient accumulation factor. Defaults to 1.
        accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.

    Returns:
        The learning rate scheduler with warmup.

    Raises:
        ValueError: If scheduler_type is not "linear" or "cosine".
    """
    NUM_WARMUP_STEPS = num_warmup_steps
    GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
    if accelerator is not None:
        accelerator.print(f"Using {scheduler_type} lr scheduler")
    if scheduler_type == "linear":
        return get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
            num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
        )
    elif scheduler_type == "cosine":
        return get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
            num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
        )
    else:
        raise ValueError(
            "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
                scheduler_type
            )
        )
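
# Illustrative usage (a sketch; the step counts are placeholder numbers):
#   lr_scheduler = get_lr_scheduler_with_warmup(
#       optimizer=optim,
#       scheduler_type="cosine",
#       num_warmup_steps=100,
#       max_train_steps=10_000,
#   )
# Note that both step counts are multiplied by grad_accumulate_every before being
# passed to the transformers schedule helpers.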


def decoupled_optimizer(
    model: torch.nn.Module,
    learning_rate: float,
    weight_decay: float,
    beta_1: float,
    beta_2: float,
    optimizer_type: str,
    use_fsdp: bool = True,
    accelerator: Accelerator = None,
):
    """
    Build an optimizer with decoupled weight decay.

    This function sets up the optimizer for the model by creating two groups of parameters:
    one with weight decay and one without. It then initializes the requested optimizer
    with these two parameter groups.

    Args:
        model (Module): The model whose parameters are optimized.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        beta_1 (float): The exponential decay rate for the 1st moment estimates.
        beta_2 (float): The exponential decay rate for the 2nd moment estimates.
        optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', 'deepspeed', or 'stable_adamw'.
        use_fsdp (bool, optional): If True, parameter names are resolved against FSDP-wrapped module names. Defaults to True.
        accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.

    Returns:
        Optimizer: The initialized optimizer.

    Raises:
        ValueError: If the optimizer type is not 'lion', 'adamw', 'deepspeed' or 'stable_adamw'.
    """
    if accelerator is not None:
        accelerator.print(f"Using {optimizer_type} optimizer")

    # Map every parameter name to its tensor so the name lists built below can be
    # resolved back to parameters.
    param_dict = {}
    for param_name, param in model.named_parameters():
        param_dict[param_name] = param

    # Collect the names of parameters that should NOT receive weight decay:
    # LayerNorm and embedding weights.
    no_decay = []

    if use_fsdp:
        exclude_module = "_fsdp_wrapped_module.token_emb"
    else:
        exclude_module = "token_emb"

    for module_name, module in model.named_modules():
        for ndim in [LayerNorm, torch.nn.Embedding]:
            if isinstance(module, ndim):
                if module_name == exclude_module:
                    no_decay.append(f"{module_name}.weight")
                else:
                    no_decay.append(f"{module_name}.gamma")
                break

    # Collect the names of parameters that SHOULD receive weight decay: linear layers.
    decay = []

    for module_name, module in model.named_modules():
        for ndim in [torch.nn.Linear]:
            if isinstance(module, ndim):
                decay.append(f"{module_name}.weight")
                break

    # Resolve the decay names to tensors, skipping the output projection.
    decay_param = []

    if use_fsdp:
        exclude_param = "_fsdp_wrapped_module.to_logits.weight"
    else:
        exclude_param = "to_logits.weight"

    for param in decay:
        if param != exclude_param:
            decay_param.append(param_dict[param])

    # Resolve the no-decay names to tensors. Names that do not exist in the model
    # (e.g. norms whose parameter is called "weight" rather than "gamma") are skipped.
    no_decay_param = []

    for param in no_decay:
        try:
            no_decay_param.append(param_dict[param])
        except KeyError:
            pass

    grouped_params = [
        {"params": decay_param, "weight_decay": weight_decay},
        {"params": no_decay_param, "weight_decay": 0.0},
    ]

    if optimizer_type == "lion":
        optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "adamw":
        optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "deepspeed":
        optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "stable_adamw":
        optimizer = StableAdamWUnfused(
            grouped_params, lr=learning_rate, betas=(beta_1, beta_2),
        )
    else:
        raise ValueError(
            "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
                optimizer_type
            )
        )

    return optimizer
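
# Illustrative usage (a sketch; mirrors the call made in Train() below):
#   optim = decoupled_optimizer(
#       model=model, learning_rate=CFG.LEARNING_RATE, weight_decay=CFG.WEIGHT_DECAY,
#       beta_1=0.90, beta_2=0.95, optimizer_type="lion", use_fsdp=True,
#       accelerator=accelerator,
#   )
# Linear weights land in the weight-decay group; LayerNorm and embedding weights do not.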


def build_dataloaders():
    """
    Build the training dataset.

    This function performs the following steps:
    1. Load the tokenizer from the pretrained "EleutherAI/gpt-neox-20b" model.
    2. Load the "openwebtext" dataset.
    3. Tokenize the dataset, appending the end-of-sequence token to each text.
    4. Chunk the tokenized dataset into blocks of CFG.SEQ_LEN tokens.

    Returns:
        Dataset: The processed dataset, ready for training.
    """
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    dataset = load_dataset("openwebtext", split="train")

    tokenized_dataset = dataset.map(
        lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
        batched=True,
        num_proc=CFG.NUM_CPU,
        remove_columns=["text"],
    )

    block_size = CFG.SEQ_LEN

    def group_texts(examples):
        # Concatenate all tokenized texts, then split them into fixed-size blocks,
        # dropping the remainder that does not fill a full block.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])

        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size

        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    train_dataset = tokenized_dataset.map(
        group_texts, batched=True, num_proc=CFG.NUM_CPU,
    )

    return train_dataset
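
# Illustrative usage (a sketch; mirrors how Train() wires the dataset below):
#   train_dataset = build_dataloaders()
#   train_loader = DataLoader(
#       train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
#   )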


def build_pre_tokenized():
    """Load a pre-tokenized dataset shard (only the first 10 examples of the split)."""
    d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train[:10]")
    return d0
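
# The "train[:10]" slice above loads only the first 10 examples of the split. Loading
# the full pre-tokenized shard would be a sketch like (an assumption, not part of the
# original configuration):
#   d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train")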


def Train():
    # Raise the distributed process-group timeout well above the default so long
    # preprocessing or setup steps do not time out.
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))

    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="fp16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )

    state = AcceleratorState()

    # Keep the DeepSpeed micro-batch size in sync with CFG.BATCH_SIZE when a
    # DeepSpeed plugin is configured.
    if state.deepspeed_plugin is not None:
        state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE

    accelerator.init_trackers(
        project_name="Andromeda",
        config={
            "batch_size": CFG.BATCH_SIZE,
            "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
            "learning_rate": CFG.LEARNING_RATE,
            "seq_len": CFG.SEQ_LEN,
        },
    )

    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    set_seed(CFG.SEED)

    model = Andromeda1Billion()

    print_num_params(model, accelerator)

    if CFG.USE_FSDP:
        model = fsdp(
            model,
            mp="fp16",
            shard_strat="SHARD_GRAD"
        )

    if CFG.USE_ACTIVATION_CHECKPOINTING:
        activation_checkpointing(model, accelerator=accelerator)

    model = accelerator.prepare(model)

    if CFG.USE_PRETOKENIZED:
        train_dataset = build_pre_tokenized()
    else:
        train_dataset = build_dataloaders()

    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
    )

    optim = decoupled_optimizer(
        model=model,
        learning_rate=CFG.LEARNING_RATE,
        weight_decay=CFG.WEIGHT_DECAY,
        beta_1=0.90,
        beta_2=0.95,
        optimizer_type='lion',
        use_fsdp=True,
        accelerator=accelerator
    )
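
    # Arithmetic note (not a config knob): the total_batch_size computed below is
    # CFG.BATCH_SIZE * num_processes * CFG.GRADIENT_ACCUMULATE_EVERY; with the
    # defaults (1 * N * 1) that is N sequences of 8192 tokens per optimizer step.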

    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    # Warm up over the first 1% of training steps.
    NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
    accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")

    lr_scheduler = get_lr_scheduler_with_warmup(
        optimizer=optim,
        scheduler_type="cosine",
        num_warmup_steps=NUM_WARMUP_STEPS,
        max_train_steps=max_train_steps,
        grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
    )

    optim, train_loader, lr_scheduler = accelerator.prepare(
        optim, train_loader, lr_scheduler
    )

    accelerator.register_for_checkpointing(lr_scheduler)

    # Recompute the step count now that the dataloader has been sharded across processes.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")

    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0

    if CFG.RESUME_FROM_CHECKPOINT:
        if CFG.RESUME_FROM_CHECKPOINT is not None and CFG.RESUME_FROM_CHECKPOINT != "":
            accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
            accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
            path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
        training_difference = os.path.splitext(path)[0]

        # Checkpoints are saved as "step_<n>"; recover the number of completed batches.
        resume_step = (
            int(training_difference.replace("step_", ""))
            * CFG.GRADIENT_ACCUMULATE_EVERY
        )

    if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
        train_loader = accelerator.skip_first_batches(train_loader, resume_step)
        completed_steps += resume_step
        progress_bar.update(resume_step)

    model.train()
    for step, batch in enumerate(train_loader):
        with accelerator.accumulate(model):
            inputs = batch["input_ids"].to(accelerator.device)
            loss = model(inputs, return_loss=True)
            accelerator.backward(loss)

            accelerator.log({"loss": loss.item()}, step=step)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optim.step()
            lr_scheduler.step()
            optim.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            if isinstance(CFG.CHECKPOINTING_STEPS, int):
                if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
                    output_dir = f"step_{completed_steps}"
                    if CFG.OUTPUT_DIR is not None:
                        output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                    accelerator.save_state(output_dir)

        if completed_steps >= max_train_steps:
            break

        if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0:
            logger.info(
                f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
            )

    accelerator.end_training()

    # Save the final model weights once training is done.
    if CFG.OUTPUT_DIR is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        with accelerator.main_process_first():
            accelerator.save(
                unwrapped_model.state_dict(), f"{CFG.OUTPUT_DIR}/final/final_model.pt"
            )
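
# Typical launch (an assumption, not fixed by this file): main() expects the
# torch.distributed environment variables, so a launcher such as
#   torchrun --nproc_per_node=8 train.py
# or `accelerate launch` must set MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE first.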


def main():
    # These lookups fail fast with a KeyError if the launcher did not set the
    # torch.distributed environment variables.
    os.environ['MASTER_ADDR']
    os.environ['MASTER_PORT']
    os.environ['RANK']
    os.environ['WORLD_SIZE']

    dist.init_process_group(backend='nccl')

    Train()


if __name__ == '__main__':
    main()