# Copyright (c) 2024 EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import math
import torch
import torch.nn as nn
from collections import defaultdict
from functools import partial
from megatron.model.utils import Lambda, SequentialWrapper, recursive_setattr
from megatron.model.norms import get_norm
from megatron.model.init_functions import get_init_methods
from megatron import mpu
from megatron.mpu import ParallelRelativePositionBias
from megatron.model.transformer import (
ParallelTransformerLayerPipe,
NormPipe,
ParallelLinearPipe,
parallel_lm_logits,
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.rwkv.v6 import RWKVResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding
# Pipeline parallelism
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
from typing import Union, List


def gpt2_attention_mask_func(attention_scores, ltor_mask):
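    """Apply the causal (left-to-right) attention mask to the raw attention scores in place."""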
mask_value = torch.finfo(attention_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(
mask_value, dtype=attention_scores.dtype, device=attention_scores.device
)
attention_scores.masked_fill_(ltor_mask, mask_value)
return attention_scores


def cross_entropy(output, labels, _fp16=False):
    """Vocab-parallel cross-entropy loss, adapted from pretrain_gpt2:forward_step().

    Original implementation::

        if self.fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss
    """
labels, loss_mask = labels[0], labels[1]
if _fp16:
assert output.dtype == torch.half and loss_mask.dtype == torch.half
losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
else:
losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
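    # Average the per-token losses over the non-masked positions only.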
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
return loss


def _pre_transformer_block(args):
    # data format change for hidden_states to avoid explicit transposes: [b s h] --> [s b h]
    assert len(args) == 2, "Incorrect number of arguments to _pre_transformer_block"
    return (args[0].transpose(0, 1).contiguous(), *args[1:])


def _post_transformer_block(args):
    # from (hidden_states, attention_mask)
    # to hidden_states (transposed back to [b s h]), dropping the attention mask
    assert len(args) == 2, "Incorrect number of arguments to _post_transformer_block"
    return args[0].transpose(0, 1).contiguous()


class GPT2ModelPipe(PipelineModule, torch.nn.Module):
"""GPT2Model adapted for pipeline parallelism.
The largest change is flattening the GPTModel class so we can express it as a
sequence of layers including embedding, transformer layers, and output.
:param neox_args: NeoX arguments object (configuration)
:param num_tokentypes: number of token types (TODO: deprecated, remove)
:param parallel_output: if true, don't gather the output logits, and calculate loss in parallel. Set to true by default in training for efficiency, but set to false for inference.
:param topology: deepspeed topology object specifying pipe / model parallelism topology.
:param use_cache: if true, cache key/value pairs for each layer in inference.
"""

    def __init__(
self,
neox_args,
num_tokentypes=0,
parallel_output=True,
topology=None,
use_cache=False,
):
self.neox_args = neox_args
self.use_cache = use_cache
self.parallel_output = parallel_output
self.hidden_size = self.neox_args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method, self.output_layer_init_method = get_init_methods(
self.neox_args
)
self.__topology__ = topology
self.specs = []
self.init_specs() # initializes the layer specs (basically a fancy nn.Sequential)
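        # Initialize the underlying DeepSpeed PipelineModule with the assembled layer specs.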
super().__init__(
layers=self.specs,
loss_fn=partial(cross_entropy, _fp16=self.neox_args.fp16_lm_cross_entropy),
topology=topology,
activation_checkpoint_interval=self.neox_args.checkpoint_num_layers
if self.neox_args.checkpoint_activations
else 0,
partition_method=neox_args.pipe_partition_method,
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"ParallelMambaResidualLayerPipe",
],
)

    def insert_layers(
        self, layers: Union[nn.Module, nn.ModuleList, nn.Sequential, List], idx
    ):
        """
        Inserts the layers in `layers` into the pipe model at `idx`.
        """
if isinstance(layers, nn.Module):
self.specs.insert(idx, layers)
        elif isinstance(layers, (nn.ModuleList, nn.Sequential)):
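            # splice the whole sequence of modules into the spec list at `idx`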
self.specs[idx:idx] = layers
elif isinstance(layers, list):
            assert all(
                callable(l) for l in layers
            ), "all items in `layers` must be callables"
self.specs[idx:idx] = layers
else:
raise ValueError(
f"layer passed into {self.__class__.__name__}.insert_layer() should be either an nn.Module, an nn.ModuleList, an nn.Sequential object, or a list of callables not a {type(layers)}"
)
        # re-initialize the parent PipelineModule so the updated specs are re-partitioned across pipeline stages
super().__init__(
layers=self.specs,
loss_fn=self.loss_fn,
topology=self.__topology__,
activation_checkpoint_interval=self.activation_checkpoint_interval,
partition_method=self.neox_args.pipe_partition_method,
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"ParallelMambaResidualLayerPipe",
"RWKVResidualLayerPipe",
],
)

    def init_specs(self):
weight_tying = not self.neox_args.no_weight_tying
self.specs = []
# Embedding layer
# input will be (input_ids, position_ids, attention_mask)
if weight_tying:
self.specs.append(
TiedLayerSpec(
"embed",
EmbeddingPipe,
self.neox_args,
self.hidden_size,
self.neox_args.padded_vocab_size,
self.neox_args.max_position_embeddings,
self.neox_args.hidden_dropout,
self.init_method,
self.num_tokentypes,
tied_weight_attr="word_embeddings_weight",
)
)
else:
self.specs.append(
LayerSpec(
EmbeddingPipe,
self.neox_args,
self.hidden_size,
self.neox_args.padded_vocab_size,
self.neox_args.max_position_embeddings,
self.neox_args.hidden_dropout,
self.init_method,
self.num_tokentypes,
)
)
# NB: the attention mask always needs to be the *last* item in the args when being passed from
# one stage to the next, because deepspeed is hacks on top of hacks.
#
# outputs are now (hidden_states, attention_mask)
self.specs.append(_pre_transformer_block)
# T5 RPE positional embedding
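        # (a single ParallelRelativePositionBias module is created here and shared by every transformer layer below)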
if self.neox_args.pos_emb == "rpe":
hidden_size_per_attention_head = mpu.divide(
self.neox_args.hidden_size, self.neox_args.num_attention_heads
)
rpe_scale = math.sqrt(hidden_size_per_attention_head)
rpe_emb = ParallelRelativePositionBias(
neox_args=self.neox_args,
scale=rpe_scale,
causal=True,
num_buckets=self.neox_args.rpe_num_buckets,
max_distance=self.neox_args.rpe_max_distance,
heads=self.neox_args.num_attention_heads,
)
# Transformer layers
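        # Each entry of `attention_config` selects the block type used for the corresponding layer.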
for i in range(self.neox_args.num_layers):
layer_type = self.neox_args.attention_config[i]
if layer_type in ["gmlp", "amlp"]:
self.specs.append(
LayerSpec(
GMLPBlock,
init_method=self.init_method,
layer_number=i,
output_layer_init_method=self.output_layer_init_method,
neox_args=self.neox_args,
mask_fn=gpt2_attention_mask_func,
)
)
elif layer_type == "rwkv":
self.specs.append(
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
layer_number=i,
)
)
elif layer_type in ["mamba"]:
self.specs.append(
LayerSpec(
ParallelMambaResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_number=i,
)
)
else:
self.specs.append(
LayerSpec(
ParallelTransformerLayerPipe,
neox_args=self.neox_args,
attention_mask_func=gpt2_attention_mask_func,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_number=i,
rpe=rpe_emb if self.neox_args.pos_emb == "rpe" else None,
rotary=self.neox_args.pos_emb == "rotary",
use_cache=self.use_cache,
)
)
# used to drop attention mask + reshape hidden states
self.specs.append(_post_transformer_block)
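        # Final norm applied to the transformer output hidden states.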
        # NormPipe is a (deprecated) helper class that used to be used to pass presents along the pipeline;
        # since presents are now cached on the `TransformerLayer` class, this is no longer needed.
norm, eps = get_norm(self.neox_args)
self.specs.append(
LayerSpec(NormPipe, norm, self.neox_args.hidden_size, eps=eps)
)
# outputs are now a single tensor: hidden_states
def _logits_helper(embedding, lm_output):
"""Just a wrapper to massage inputs/outputs from pipeline."""
if self.neox_args.use_mup:
# Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout.
# https://github.com/microsoft/mup/issues/6#issuecomment-1082156274
lm_output = (
lm_output
/ self.tied_modules.embed.word_embeddings.weight.infshape.width_mult()
)
logits = parallel_lm_logits(
lm_output,
embedding.word_embeddings_weight,
self.parallel_output,
seq_parallel=self.neox_args.sequence_parallel,
)
return logits
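        # Output (LM head) layer: reuse the tied embedding weights when weight tying is enabled,
        # otherwise project with a standalone ParallelLinearPipe.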
if weight_tying:
self.specs.append(
TiedLayerSpec(
"embed",
EmbeddingPipe,
self.neox_args,
self.hidden_size,
self.neox_args.padded_vocab_size,
self.neox_args.max_position_embeddings,
self.neox_args.hidden_dropout,
self.init_method,
self.num_tokentypes,
forward_fn=_logits_helper,
tied_weight_attr="word_embeddings_weight",
)
)
else:
self.specs.append(
LayerSpec(
ParallelLinearPipe,
neox_args=self.neox_args,
init_method=self.init_method,
parallel_output=self.parallel_output,
is_last_layer=True,
)
)

    def _set_parallel_output(self, value):
# sets the parallel output value of the final layer to value
final_layer = list(self.forward_funcs)[-1]
if isinstance(final_layer, (ParallelLinearPipe, ParallelLinear)):
final_layer.final_linear.set_parallel_output(value)

    def inference_mode(self, use_cache=True):
        """
        Sets up the model for inference by turning on k/v caching (if specified) and setting `parallel output` of the final layer to False,
        so logits are gathered across model-parallel ranks.

        :param use_cache: (bool) True if you want to use caching during inference, False otherwise
        """
# first set caching to true if specified
recursive_setattr(self.forward_funcs, "use_cache", use_cache, assert_type=bool)
# then set parallel output of the final layer to false so we don't have to gather the output manually
self._set_parallel_output(False)
recursive_setattr(self.forward_funcs, "training", False)

    def train_mode(self):
"""
Sets up the model for training by turning off k/v caching and setting `parallel output` of the final layer to True,
so logits are not gathered across model parallel ranks, and loss is computed in parallel (more efficient).
"""
# set caching to false
recursive_setattr(self.forward_funcs, "use_cache", False)
# then set parallel output to true (more efficient training)
self._set_parallel_output(True)
recursive_setattr(self.forward_funcs, "training", True)

    def clear_cache(self):
        """
        Recursively clears the k/v cache on all layers.
        """
recursive_setattr(self.forward_funcs, "layer_past", None)

    def to_sequential(self):
        """
        Transforms the PipelineModule into a plain nn.Sequential module.

        :return: a SequentialWrapper containing the built layers.
        """
layers = []
tied_layers = defaultdict(list)
for n, spec in enumerate(self.specs):
if isinstance(spec, TiedLayerSpec):
if spec.key in tied_layers:
# receiver
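                    # wrap the already-built owner module so this position reuses its weights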
                    layers.append(
                        # bind `spec` as a default arg so the lambda does not capture the loop variable late
                        Lambda(lambda x, spec=spec: spec.forward_fn(tied_layers[spec.key][0], x))
                    )
else:
# owner
module = spec.build(log=False)
layers.append(module)
tied_layers[spec.key].append(module)
elif isinstance(spec, LayerSpec):
layers.append(spec.build(log=False))
            elif callable(spec):
                # a plain callable (e.g. _pre_transformer_block / _post_transformer_block)
layers.append(Lambda(spec))
else:
raise ValueError(f"Layer number {n} ({spec}) Not recognized")
model = SequentialWrapper(
layers,
self.activation_checkpoint_interval,
self.activation_checkpoint_func,
parent_class_name=self.__class__.__name__,
)
return model