File size: 6,030 Bytes
d90b3a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version
#
import os
import pathlib
import subprocess
import torch
from torch.utils import cpp_extension
# Setting TORCH_CUDA_ARCH_LIST to a list of architectures has the problem of
# generating different compilation commands (with the architectures in a
# different order), which leads to recompilation of the fused kernels.
# Set it to an empty string to avoid that recompilation; the arch flags are
# assigned explicitly through extra_cuda_cflags in load() below.
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def load(neox_args=None):
    """JIT-compile and load the fused kernels that live next to this file.

    Three extensions are built with ``torch.utils.cpp_extension.load`` into a
    ``build`` directory beside this module: the scaled upper-triangular masked
    softmax, the scaled masked softmax, and the fused rotary positional
    embedding.

    Args:
        neox_args: Optional arguments object. When given, the build is verbose
            only if ``neox_args.rank == 0``; when ``None`` the build is always
            verbose.

    Returns:
        None. The loaded extension modules are not returned.
    """
    is_hip = torch.version.hip is not None

    # Extra -gencode flags, chosen from the installed CUDA toolkit version.
    # Skipped entirely on HIP/ROCm builds.
    cc_flag = []
    if not is_hip:
        _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
            cpp_extension.CUDA_HOME
        )
        toolkit = (int(bare_metal_major), int(bare_metal_minor))
        # (minimum toolkit version, compute capability). Comparing the full
        # (major, minor) tuple keeps every architecture enabled on newer major
        # releases: checking the minor number alone would drop sm_87/sm_89 on
        # e.g. CUDA 12.1. The order is fixed so the generated command line
        # stays stable and does not trigger needless recompilation.
        arch_by_min_toolkit = (
            ((11, 0), "80"),
            ((11, 1), "86"),
            ((11, 4), "87"),
            ((11, 8), "89"),
            ((12, 0), "90"),
        )
        for min_toolkit, arch in arch_by_min_toolkit:
            if toolkit >= min_toolkit:
                cc_flag.append("-gencode")
                cc_flag.append(f"arch=compute_{arch},code=sm_{arch}")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Determine verbosity
    verbose = True if neox_args is None else (neox_args.rank == 0)

    # Helper function to build the kernels.
    def _cpp_extension_load_helper(
        name, sources, extra_cuda_flags, extra_include_paths
    ):
        if is_hip:
            extra_cuda_cflags = ["-O3"] + extra_cuda_flags + cc_flag
        else:
            # sm_70 is always targeted on CUDA builds, in addition to the
            # toolkit-dependent architectures collected above.
            extra_cuda_cflags = (
                ["-O3", "-gencode", "arch=compute_70,code=sm_70", "--use_fast_math"]
                + extra_cuda_flags
                + cc_flag
            )

        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=extra_cuda_cflags,
            extra_include_paths=extra_include_paths,
            verbose=verbose,
        )

    # Backend-specific include paths and half-precision defines shared by all
    # three extensions.
    if is_hip:
        extra_include_paths = [os.path.abspath(srcpath)]
        extra_cuda_flags = [
            "-D__HIP_NO_HALF_OPERATORS__=1",
            "-D__HIP_NO_HALF_CONVERSIONS__=1",
        ]
    else:
        extra_include_paths = []
        extra_cuda_flags = [
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
        ]

    # (extension name, source files), built in this order.
    extensions = (
        # Upper triangular softmax.
        (
            "scaled_upper_triang_masked_softmax_cuda",
            [
                srcpath / "scaled_upper_triang_masked_softmax.cpp",
                srcpath / "scaled_upper_triang_masked_softmax_cuda.cu",
            ],
        ),
        # Masked softmax.
        (
            "scaled_masked_softmax_cuda",
            [
                srcpath / "scaled_masked_softmax.cpp",
                srcpath / "scaled_masked_softmax_cuda.cu",
            ],
        ),
        # Fused rope.
        (
            "fused_rotary_positional_embedding",
            [
                srcpath / "fused_rotary_positional_embedding.cpp",
                srcpath / "fused_rotary_positional_embedding_cuda.cu",
            ],
        ),
    )
    for name, sources in extensions:
        _cpp_extension_load_helper(
            name, sources, extra_cuda_flags, extra_include_paths
        )
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
def load_fused_kernels():
    """Check that the compiled fused-kernel extension modules are importable.

    Imports the three extension modules by name. If any import fails, a
    diagnostic banner is printed and the process is terminated with a
    non-zero exit status so that callers and launchers see the failure.

    Raises:
        SystemExit: With code 1 when any fused-kernel module cannot be
            imported.
    """
    try:
        import scaled_upper_triang_masked_softmax_cuda  # noqa: F401
        import scaled_masked_softmax_cuda  # noqa: F401
        import fused_rotary_positional_embedding  # noqa: F401
    except ImportError as e:  # ModuleNotFoundError is a subclass of ImportError
        print("\n")
        print(e)
        print("=" * 100)
        print(
            "ERROR: Fused kernels configured but not properly installed. Please run `from megatron.fused_kernels import load` then `load()` to load them correctly"
        )
        print("=" * 100)
        # Raise SystemExit directly: the `exit` builtin is only injected by the
        # `site` module, and calling it without a code reports success (0).
        raise SystemExit(1) from e
|