# megatron/mpu/mappings.py
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .initialize import (
get_model_parallel_group,
get_model_parallel_world_size,
get_model_parallel_rank,
get_fp32_allreduce,
)
from .utils import split_tensor_along_last_dim, split_tensor_along_any_dim
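
# The functions below implement the collective communication primitives used
# for tensor (model) parallelism. Each primitive is paired with a conjugate
# that serves as its gradient in the autograd.Function wrappers further down:
#   identity       <-> all-reduce  (copy to / reduce from the model parallel region)
#   split          <-> all-gather  (scatter to / gather from the model parallel region)
#   reduce-scatter <-> all-gather  (sequence parallel variants along the sequence dim)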
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_model_parallel_world_size() == 1:
return input_
# upcast to fp32 if using fp32 allreduce
dt = input_.dtype
if get_fp32_allreduce():
input_ = input_.float()
# All-reduce.
torch.distributed.all_reduce(input_, group=get_model_parallel_group())
    # convert back to the original bf16/fp16 dtype
if get_fp32_allreduce():
input_ = input_.to(dt)
return input_
def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
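
# The three functions below operate along a caller-specified sequence dimension
# and back the sequence parallel autograd.Functions defined further down.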
def _reduce_scatter_along_seq_dim(input_, seq_dim):
"""Reduce-scatter the input tensor across model parallel group, scattering across sequence dim."""
world_size = get_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# upcast to fp32 if using fp32 allreduce
dt = input_.dtype
if get_fp32_allreduce():
input_ = input_.float()
dim_size = list(input_.size())
assert (
isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
), "seq_dim must be a valid tensor dim"
    assert (
        dim_size[seq_dim] % world_size == 0
    ), "sequence dimension size must be divisible by the model parallel world size"
if seq_dim == 0:
# reduce_scatter_tensor is faster but only works correctly on dimension 0
dim_size[seq_dim] = dim_size[seq_dim] // world_size
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=get_model_parallel_group()
)
else:
tensor_list = list(
torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
)
output = torch.empty_like(tensor_list[0])
torch.distributed.reduce_scatter(
output, tensor_list, group=get_model_parallel_group()
)
    # convert back to the original bf16/fp16 dtype
if get_fp32_allreduce():
output = output.to(dt)
return output
def _gather_along_seq_dim(input_, seq_dim):
"""Gather tensors and concatinate along the (manually-specified) sequence dimension."""
world_size = get_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert (
isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
), "seq_dim must be a valid tensor dim"
dim_size[seq_dim] = dim_size[seq_dim] * world_size
if seq_dim == 0:
        # all_gather_into_tensor is faster but only works correctly on dimension 0
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=get_model_parallel_group()
)
else:
input_ = input_.contiguous()
rank = get_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(
tensor_list, input_, group=get_model_parallel_group()
)
output = torch.cat(tensor_list, dim=seq_dim)
return output
def _split_along_seq_dim(input_, seq_dim):
"""Split the tensor along the sequence dimension (as manually selected) and keep the
corresponding slice."""
world_size = get_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
    # Split along the sequence dimension.
input_list = split_tensor_along_any_dim(input_, world_size, seq_dim)
# Note: torch.split does not create contiguous tensors by default.
rank = get_model_parallel_rank()
output = input_list[rank].contiguous()
return output
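
# torch.autograd.Function wrappers. Each class applies one of the primitives
# above in forward() and its conjugate in backward(), so gradients are
# communicated correctly across the model parallel group.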
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce-Scatter across sequence parallel region (same as model parallel region.)
Note: same region as model parallel region
"""
@staticmethod
def symbolic(graph, input_, seq_dim):
return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def forward(ctx, input_, seq_dim):
ctx.seq_dim = seq_dim
return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def backward(ctx, grad_output):
seq_dim = ctx.seq_dim
return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""All-Gather across sequence parallel region (same region as model parallel region.)"""
@staticmethod
def symbolic(graph, input_, seq_dim):
return _gather_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def forward(ctx, input_, seq_dim):
ctx.seq_dim = seq_dim
return _gather_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def backward(ctx, grad_output):
seq_dim = ctx.seq_dim
return _reduce_scatter_along_seq_dim(grad_output, seq_dim=seq_dim), None
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Scatter (split) sequence length across sequence parallel region (=> same region as model parallel.)"""
@staticmethod
def symbolic(graph, input_, seq_dim):
return _split_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def forward(ctx, input_, seq_dim):
ctx.seq_dim = seq_dim
return _split_along_seq_dim(input_, seq_dim=seq_dim)
@staticmethod
def backward(ctx, grad_output):
seq_dim = ctx.seq_dim
return (
_gather_along_seq_dim(grad_output, seq_dim=seq_dim),
None,
)
# -----------------
# Helper functions.
# -----------------
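
# Thin public wrappers around the autograd.Functions above; these are what the
# parallel layers (e.g. column/row parallel linear layers) call directly.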
def copy_to_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0):
return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim)
def gather_from_sequence_parallel_region(input_, seq_dim=0):
return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
def scatter_to_sequence_parallel_region(
    input_, seq_dim=1
):  # use this fn to scatter input embeddings across TP ranks; there the inputs have shape [b, s, h] instead of the usual [s, b, h]
return _ScatterToSequenceParallelRegion.apply(input_, seq_dim)