Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +6 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/__init__.py +96 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/_common.py +268 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/_ipc_utils.py +139 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/_utils.py +525 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/__init__.py +9 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/auto_parallel.py +263 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/cluster_info.py +556 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/config.py +61 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/device_mesh.py +612 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/node_graph.py +347 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py +0 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/pipeline_graph.py +1035 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/runtime_profiling.py +150 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/shape_info.py +362 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/simplifier.py +837 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/solver.py +641 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py +0 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py +41 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py +34 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py +45 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py +58 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py +56 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py +45 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py +49 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py +59 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py +196 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py +56 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py +79 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py +798 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/node.py +376 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py +60 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py +79 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py +67 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py +40 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py +0 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py +27 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py +395 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py +11 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py +19 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py +28 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py +73 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py +56 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_consistency.py +832 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_node.py +41 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_spec.py +418 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_strategy.py +77 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shuffle_node.py +238 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/slice_node.py +100 -0
- lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/softmax_node.py +54 -0
.gitattributes
CHANGED
@@ -34,3 +34,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 tensorrt_llm-0.12.0.dev2024072300-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/bin/executorWorker filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libdecoder_attention.so filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libtensorrt_llm.so filter=lfs diff=lfs merge=lfs -text
+lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libtensorrt_llm_nvrtc_wrapper.so filter=lfs diff=lfs merge=lfs -text
lib.linux-x86_64-cpython-310/tensorrt_llm/__init__.py
ADDED
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def _add_trt_llm_dll_directory():
    import platform
    on_windows = platform.system() == "Windows"
    if on_windows:
        import os
        import sysconfig
        from pathlib import Path
        os.add_dll_directory(
            Path(sysconfig.get_paths()['purelib']) / "tensorrt_llm" / "libs")


_add_trt_llm_dll_directory()

import sys

import tensorrt_llm.functional as functional
import tensorrt_llm.models as models
import tensorrt_llm.quantization as quantization
import tensorrt_llm.runtime as runtime
import tensorrt_llm.tools as tools

from ._common import _init, default_net, default_trtnet, precision
# Disable flake8 on the line below because mpi_barrier is not used in tensorrt_llm project
# but may be called in dependencies (such as examples)
from ._utils import mpi_barrier  # NOQA
from ._utils import str_dtype_to_torch  # NOQA
from ._utils import (mpi_rank, mpi_world_size, str_dtype_to_trt,
                     torch_dtype_to_trt)
from .auto_parallel import AutoParallelConfig, auto_parallel
from .builder import BuildConfig, Builder, BuilderConfig, build
from .functional import Tensor, constant
from .hlapi.llm import LLM, LlmArgs, SamplingParams
from .logger import logger
from .mapping import Mapping
from .module import Module
from .network import Network, net_guard
from .parameter import Parameter
from .version import __version__

__all__ = [
    'logger',
    'str_dtype_to_trt',
    'torch_dtype_to_trt',
    'str_dtype_to_torch'
    'mpi_barrier',
    'mpi_rank',
    'mpi_world_size',
    'constant',
    'default_net',
    'default_trtnet',
    'precision',
    'net_guard',
    'Network',
    'Mapping',
    'Builder',
    'BuilderConfig',
    'build',
    'BuildConfig',
    'Tensor',
    'Parameter',
    'runtime',
    'Module',
    'functional',
    'models',
    'auto_parallel',
    'AutoParallelConfig',
    'quantization',
    'tools',
    'LLM',
    'LlmArgs',
    'SamplingParams',
    'KvCacheConfig',
    '__version__',
]

_init(log_level="error")

print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}")

sys.stdout.flush()
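For orientation, the `__init__.py` above eagerly imports the submodules, runs `_init(log_level="error")`, and prints a version banner, so a bare import is enough to exercise it. A minimal sketch, assuming the wheel from this upload (and its `tensorrt`/`torch` dependencies) is installed:

import tensorrt_llm  # prints "[TensorRT-LLM] TensorRT-LLM version: ..." on import

# Helpers re-exported at the top level by the file above.
print(tensorrt_llm.__version__)
print(tensorrt_llm.str_dtype_to_trt("float16"))  # trt.DataType.HALF
print(tensorrt_llm.mpi_rank(), tensorrt_llm.mpi_world_size())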
lib.linux-x86_64-cpython-310/tensorrt_llm/_common.py
ADDED
@@ -0,0 +1,268 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import ctypes
import os
import platform
import time
from functools import wraps
from pathlib import Path

import numpy as np

# isort: off
import torch
import tensorrt as trt

# isort: on

from ._utils import str_dtype_to_trt
from .bindings import MpiComm
from .logger import logger
from .plugin import _load_plugin_lib

net = None

_inited = False


def _init(log_level: object = None) -> None:
    global _inited
    if _inited:
        return
    _inited = True
    # Move to __init__
    if log_level is not None:
        logger.set_level(log_level)

    if os.getenv("TRT_LLM_NO_LIB_INIT", "0") == "1":
        logger.info('Skipping TensorRT-LLM init.')
        return

    logger.info('Starting TensorRT-LLM init.')

    # load plugin lib
    _load_plugin_lib()

    # load FT decoder layer
    project_dir = str(Path(__file__).parent.absolute())
    if platform.system() == "Windows":
        ft_decoder_lib = project_dir + '/libs/th_common.dll'
    else:
        ft_decoder_lib = project_dir + '/libs/libth_common.so'
    try:
        torch.classes.load_library(ft_decoder_lib)
    except Exception as e:
        msg = '\nFATAL: Decoding operators failed to load. This may be caused by the incompatibility between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM.'
        raise ImportError(str(e) + msg)

    MpiComm.local_init()

    logger.info('TensorRT-LLM inited.')


def default_net():
    assert net, "Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
    return net


def default_trtnet():
    return default_net().trt_network


def set_network(network):
    global net
    net = network


def switch_net_dtype(cur_dtype):
    prev_dtype = default_net().dtype
    default_net().dtype = cur_dtype
    return prev_dtype


@contextlib.contextmanager
def precision(dtype):
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    prev_dtype = switch_net_dtype(dtype)
    yield
    switch_net_dtype(prev_dtype)


def serialize_engine(engine, path):
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    if isinstance(engine, trt.ICudaEngine):
        engine = engine.serialize()
    with open(path, 'wb') as f:
        f.write(engine)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


def deserialize_engine(path):
    runtime = trt.Runtime(logger.trt_logger)
    with open(path, 'rb') as f:
        logger.info(f'Loading engine from {path}...')
        tik = time.time()

        engine = runtime.deserialize_cuda_engine(f.read())
        assert engine is not None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Engine loaded. Total time: {t}')
    return engine


_field_dtype_to_np_dtype_dict = {
    trt.PluginFieldType.FLOAT16: np.float16,
    trt.PluginFieldType.FLOAT32: np.float32,
    trt.PluginFieldType.FLOAT64: np.float64,
    trt.PluginFieldType.INT8: np.int8,
    trt.PluginFieldType.INT16: np.int16,
    trt.PluginFieldType.INT32: np.int32,
}


def field_dtype_to_np_dtype(dtype):
    ret = _field_dtype_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def convert_capsule_to_void_p(capsule):
    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
        ctypes.py_object, ctypes.c_char_p
    ]
    return ctypes.pythonapi.PyCapsule_GetPointer(capsule, None)


def get_nparray_from_void_p(void_pointer, elem_size, field_dtype):
    ctypes.pythonapi.PyMemoryView_FromMemory.restype = ctypes.py_object
    ctypes.pythonapi.PyMemoryView_FromMemory.argtypes = [
        ctypes.c_char_p, ctypes.c_ssize_t, ctypes.c_int
    ]
    logger.info(
        f'get_nparray: pointer = {void_pointer}, elem_size = {elem_size}')
    char_pointer = ctypes.cast(void_pointer, ctypes.POINTER(ctypes.c_char))
    np_dtype = field_dtype_to_np_dtype(field_dtype)
    buf_bytes = elem_size * np.dtype(np_dtype).itemsize
    logger.info(f'get_nparray: buf_bytes = {buf_bytes}')
    mem_view = ctypes.pythonapi.PyMemoryView_FromMemory(
        char_pointer, buf_bytes, 0)  # number 0 represents PyBUF_READ
    logger.info(
        f'get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}')
    buf = np.frombuffer(mem_view, np_dtype)
    return buf


def get_scalar_from_field(field):
    void_p = convert_capsule_to_void_p(field.data)
    np_array = get_nparray_from_void_p(void_p, 1, field.type)
    return np_array[0]


class _BuildingFlag:

    def __enter__(self):
        os.environ['IS_BUILDING'] = '1'
        os.environ[
            '__LUNOWUD'] = '-cask_fusion:fp8=off'  # will be removed in future releases

    def __exit__(self, type, value, tb):
        del os.environ['IS_BUILDING']
        del os.environ['__LUNOWUD']


def _is_building(f):
    '''Use this to decorate functions which are called during engine building/refitting process,
    otherwise, the plugin registration will fail.
    '''

    @wraps(f)
    def decorated(*args, **kwargs):
        with _BuildingFlag():
            return f(*args, **kwargs)

    return decorated


def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
                         max_input_len, max_seq_len, max_beam_width,
                         remove_input_padding, enable_context_fmha,
                         tokens_per_block, multiple_profiles):
    if not remove_input_padding:
        if max_num_tokens is not None or opt_num_tokens is not None:
            max_num_tokens = max_batch_size * max_seq_len
            logger.warning("remove_input_padding is not enabled, the specified "
                           "max_num_tokens/opt_num_tokens will be ignored.")
        return max_num_tokens, opt_num_tokens
    else:
        if max_num_tokens is None:
            max_num_tokens = max_seq_len * max_batch_size
            logger.warning(
                "remove_input_padding is enabled, while max_num_tokens "
                "is not set, setting to max_batch_size*max_seq_len. \n"
                "It may not be optimal to set max_num_tokens=max_batch_size*max_seq_len "
                "when remove_input_padding is enabled, because the number "
                "of packed input tokens are very likely to be smaller, "
                "we strongly recommend to set max_num_tokens according "
                "to your workloads.")
        if opt_num_tokens is None and not multiple_profiles:
            opt_num_tokens = min(max_batch_size * max_beam_width,
                                 max_num_tokens)
            logger.warning(
                "remove_input_padding is enabled, while opt_num_tokens "
                "is not set, setting to max_batch_size*max_beam_width. \n")
        if max_num_tokens > 16384:
            logger.warning(
                "Specifying a `max_num_tokens` larger than 16384 is usually "
                "not recommended, we do not expect perf gain with that and too "
                "large `max_num_tokens` could possibly exceed the TensorRT "
                "tensor volume, causing runtime errors. "
                f"Got `max_num_tokens` = {max_num_tokens}")
        if max_num_tokens > max_seq_len * max_batch_size:
            max_num_tokens = max_seq_len * max_batch_size
            logger.warning(
                f"max_num_tokens ({max_num_tokens}) shouldn't be greater than "
                f"max_seq_len * max_batch_size ({max_seq_len * max_batch_size}), "
                f"specifying to max_seq_len * max_batch_size ({max_seq_len * max_batch_size})."
            )
        if max_num_tokens < max_input_len and not enable_context_fmha:
            logger.warning(
                f"When enable_context_fmha is not turned on, max_num_tokens ({max_num_tokens}) "
                f"should be at least max_input_len ({max_input_len}), specifying to "
                f"max_input_len ({max_input_len}).")
            max_num_tokens = max_input_len
        elif max_num_tokens < tokens_per_block and enable_context_fmha:
            logger.warning(
                f"When enable_context_fmha is turned on, max_num_tokens ({max_num_tokens}) "
                f"should be at least tokens_per_block ({tokens_per_block}), specifying to "
                f"tokens_per_block ({tokens_per_block}). At this time, you also need to enable "
                f"context chunking at runtime, otherwise you may encounter errors.")
            max_num_tokens = tokens_per_block

    if opt_num_tokens is not None and opt_num_tokens > max_num_tokens:
        logger.warning(
            f"opt_num_tokens ({opt_num_tokens}) shouldn't be greater than "
            f"max_num_tokens ({max_num_tokens}), "
            f"specifying to max_num_tokens ({max_num_tokens}).")
        opt_num_tokens = max_num_tokens

    return max_num_tokens, opt_num_tokens
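The `check_max_num_tokens` helper above only emits warnings through `logger`, so its fallback logic can be exercised without building an engine. A small sketch with illustrative values (not taken from this upload), assuming the wheel is installed:

from tensorrt_llm._common import check_max_num_tokens

# Both token counts left unset; remove_input_padding is enabled.
max_num_tokens, opt_num_tokens = check_max_num_tokens(
    max_num_tokens=None,
    opt_num_tokens=None,
    max_batch_size=8,
    max_input_len=1024,
    max_seq_len=2048,
    max_beam_width=1,
    remove_input_padding=True,
    enable_context_fmha=True,
    tokens_per_block=64,
    multiple_profiles=False,
)
# Falls back to max_batch_size * max_seq_len = 16384 for max_num_tokens and to
# min(max_batch_size * max_beam_width, max_num_tokens) = 8 for opt_num_tokens.
print(max_num_tokens, opt_num_tokens)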
lib.linux-x86_64-cpython-310/tensorrt_llm/_ipc_utils.py
ADDED
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import struct
import sys
from contextlib import contextmanager
from typing import List, Tuple

from cuda import cudart
from cuda.cudart import cudaError_t

from ._utils import mpi_comm
from .mapping import Mapping


def _raise_if_error(error: cudaError_t):
    if error != cudaError_t.cudaSuccess:
        raise RuntimeError(error)


@contextmanager
def peer_access(mapping: Mapping):
    set_peer_access(mapping, True)
    try:
        yield
    finally:
        set_peer_access(mapping, False)


def set_peer_access(mapping: Mapping, enabled: bool = True):
    src_node = mapping.local_rank
    for rank in mapping.tp_group:
        dest_node = mapping.get_local_rank(rank)
        if mapping.get_node_rank(
                rank) != mapping.node_rank or dest_node == src_node:
            continue

        error, result = cudart.cudaDeviceCanAccessPeer(src_node, dest_node)
        _raise_if_error(error)

        if result == 0:
            raise RuntimeError(
                f"Can't enable access between nodes {src_node} and {dest_node}")

        if enabled:
            cudart.cudaDeviceEnablePeerAccess(dest_node, 0)
        else:
            cudart.cudaDeviceDisablePeerAccess(dest_node)
        error = cudart.cudaGetLastError()[0]
        if error not in [
                cudaError_t.cudaSuccess,
                cudaError_t.cudaErrorPeerAccessAlreadyEnabled,
                cudaError_t.cudaErrorPeerAccessNotEnabled
        ]:
            raise RuntimeError(error)


class IpcMemory():

    # WARNING: Must in sync with FLAGS_SIZE in cpp/include/tensorrt_llm/runtime/ipcUtils.h
    # (Max all reduce blocks + 1) * sizeof(int)
    IPC_BARRIERS_SIZE_PER_GPU = (24 + 1) * 4

    def __init__(self, mapping: Mapping, size: int):
        self.mapping = mapping
        self.open_ipc = mapping.tp_size <= mapping.gpus_per_node
        if self.open_ipc:
            self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(
                self.mapping, size, True)
        else:
            self.peer_ptrs = [0] * mapping.tp_size
            self.local_ptr = 0

    def __del__(self):
        if not sys.is_finalizing() and self.open_ipc:
            IpcMemory.close_ipc_memory(self.mapping, self.peer_ptrs)

    def serialize(self) -> List[int]:
        buffer = bytes(0)
        for ptr in self.peer_ptrs:
            buffer += struct.pack("P", ptr)

        return array.array("Q", buffer).tolist()

    @staticmethod
    def open_ipc_memory(mapping: Mapping,
                        size: int,
                        set_to_zero: bool = False) -> Tuple[List[int], int]:
        """ Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
        Returns a list of buffer pointers, buffers[i] is a handle to the corresponding buffer residing on GPU #i.
        Call close_ipc_handle with the *buffer*.
        """
        comm = mpi_comm().Split(mapping.pp_rank, mapping.tp_rank)

        error, local_ptr = cudart.cudaMalloc(size)
        _raise_if_error(error)
        if set_to_zero:
            _raise_if_error(cudart.cudaMemset(local_ptr, 0, size)[0])
        error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
        _raise_if_error(error)

        handles_reserved = comm.allgather(local_handle.reserved)
        handles = []
        for reserved in handles_reserved:
            handle = cudart.cudaIpcMemHandle_t()
            handle.reserved = reserved
            handles.append(handle)

        peer_ptrs = []
        for node, handle in enumerate(handles):
            if node == mapping.tp_rank:
                peer_ptrs.append(local_ptr)
            else:
                error, ptr = cudart.cudaIpcOpenMemHandle(
                    handle, cudart.cudaIpcMemLazyEnablePeerAccess)
                _raise_if_error(error)
                peer_ptrs.append(ptr)

        return peer_ptrs, local_ptr

    @staticmethod
    def close_ipc_memory(mapping: Mapping, peer_ptrs: List[int]):
        for node, ptr in enumerate(peer_ptrs):
            if node == mapping.tp_rank:
                _raise_if_error(cudart.cudaFree(ptr)[0])
            else:
                _raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])
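`IpcMemory.serialize` above packs each peer pointer as a native pointer and reinterprets the buffer as unsigned 64-bit integers. A standalone sketch of just that packing step, with made-up addresses and assuming a 64-bit platform where `struct.pack("P", ...)` yields 8 bytes:

import array
import struct

peer_ptrs = [0x7f0000000000, 0x7f0000001000, 0]  # illustrative device pointers
buffer = bytes(0)
for ptr in peer_ptrs:
    buffer += struct.pack("P", ptr)  # native pointer width, 8 bytes on 64-bit hosts
print(array.array("Q", buffer).tolist())  # [139637976727552, 139637976731648, 0]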
lib.linux-x86_64-cpython-310/tensorrt_llm/_utils.py
ADDED
@@ -0,0 +1,525 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import inspect
import json
import math
import struct
import weakref
from dataclasses import asdict
from enum import EnumMeta
from functools import partial
from typing import Any, Dict, List, Optional, Union

import numpy as np
from packaging import version

from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE

# isort: off
import torch
import tensorrt as trt
# isort: on

# numpy doesn't know bfloat16, define abstract binary type instead
np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})
np_float8 = np.dtype('V1', metadata={"dtype": "float8"})


def torch_to_numpy(x: torch.Tensor):
    assert isinstance(x, torch.Tensor), \
        f'x must be a torch.Tensor object, but got {type(x)}.'
    if x.dtype == torch.bfloat16:
        return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)
    elif x.dtype == torch.float8_e4m3fn:
        return x.view(torch.int8).detach().cpu().numpy().view(np_float8)
    else:
        return x.detach().cpu().numpy()


def numpy_to_torch(x):
    if x.dtype == np_bfloat16:
        return torch.from_numpy(x.view(np.int16)).view(torch.bfloat16)
    elif x.dtype == np_float8:
        return torch.from_numpy(x.view(np.int8)).view(torch.float8_e4m3fn)
    else:
        return torch.from_numpy(x)


def numpy_to_dtype(x, dtype: str):
    if str_dtype_to_np(dtype) == x.dtype:
        return x
    if x.dtype not in [np_bfloat16, np_float8
                       ] and dtype not in ['bfloat16', 'fp8']:
        return x.astype(str_dtype_to_np(dtype))
    else:
        return torch_to_numpy(numpy_to_torch(x).to(str_dtype_to_torch(dtype)))


fp32_array = partial(np.array, dtype=np.float32)
fp16_array = partial(np.array, dtype=np.float16)
int32_array = partial(np.array, dtype=np.int32)
int64_array = partial(np.array, dtype=np.int64)
bool_array = partial(np.array, dtype=np.bool_)


def dims_array(x):
    is_int64_dims = True
    try:
        trt.Dims([np.iinfo(np.int64).max])
    except TypeError:
        is_int64_dims = False
    return int64_array(x) if is_int64_dims else int32_array(x)


def bf16_array(x):
    x = torch.tensor(x, dtype=torch.bfloat16)
    x = torch_to_numpy(x)
    return x


def numpy_array(data, trt_dtype):
    # convenient wrapper due to numpy not support bf16 yet
    if trt_dtype == trt.bfloat16:
        return bf16_array(data)
    return np.array(data, trt_dtype_to_np(trt_dtype))


def copy_torch_to_numpy(x: torch.Tensor, ndarray: np.array):
    if x.dtype == torch.bfloat16:
        torch.from_numpy(ndarray.view(np.int16)).copy_(x.view(torch.int16))
    elif x.dtype == torch.float8_e4m3fn:
        torch.from_numpy(ndarray.view(np.int8)).copy_(x.view(torch.int8))
    else:
        torch.from_numpy(ndarray).copy_(x)
    return ndarray


def trt_version():
    return trt.__version__


# TRT supports strongly_typed in 9.1
def support_strongly_type():
    return version.parse(trt_version()) >= version.parse("9.1.0")


# Check if TRT version >= 10
def trt_gte_10():
    return version.parse(trt_version()).major > 9


# Check if TRT version >= 10.1
def trt_gte_10_1():
    trt_ver = version.parse(trt_version())
    return trt_ver.major > 9 and trt_ver.minor > 0


# Check if TRT version >= 10.2
def trt_gte_10_2():
    ver = version.parse(trt_version())
    return (ver.major * 10 + ver.minor) >= 102


def torch_version():
    return torch.__version__


_str_to_np_dict = dict(
    float16=np.float16,
    float32=np.float32,
    int64=np.int64,
    int32=np.int32,
    int8=np.int8,
    bool=np.bool_,
    bfloat16=np_bfloat16,
    fp8=np_float8,
)


def str_dtype_to_np(dtype):
    ret = _str_to_np_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_str_to_torch_dtype_dict = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
    float32=torch.float32,
    int64=torch.int64,
    int32=torch.int32,
    int8=torch.int8,
    bool=torch.bool,
    fp8=torch.float8_e4m3fn,
)


def str_dtype_to_torch(dtype):
    ret = _str_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_dtype_to_str_dict = {v: k for k, v in _str_to_torch_dtype_dict.items()}


def torch_dtype_to_str(dtype):
    return _torch_dtype_to_str_dict[dtype]


_str_to_trt_dtype_dict = dict(float16=trt.float16,
                              float32=trt.float32,
                              int64=trt.int64,
                              int32=trt.int32,
                              int8=trt.int8,
                              bool=trt.bool,
                              bfloat16=trt.bfloat16,
                              fp8=trt.fp8)


def str_dtype_to_trt(dtype):
    ret = _str_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_str_dtype_dict = {v: k for k, v in _str_to_trt_dtype_dict.items()}


def trt_dtype_to_str(dtype: trt.DataType) -> str:
    assert isinstance(dtype, trt.DataType)
    return _trt_to_str_dtype_dict[dtype]


_np_to_trt_dtype_dict = {
    np.int8: trt.int8,
    np.int32: trt.int32,
    np.int64: trt.int64,
    np.float16: trt.float16,
    np.float32: trt.float32,
    np.bool_: trt.bool,

    # hash of np.dtype('int32') != np.int32
    np.dtype('int8'): trt.int8,
    np.dtype('int32'): trt.int32,
    np.dtype('int64'): trt.int64,
    np.dtype('float16'): trt.float16,
    np.dtype('float32'): trt.float32,
    np.dtype('bool'): trt.bool,
    np_bfloat16: trt.bfloat16,
    np_float8: trt.fp8,
}


def np_dtype_to_trt(dtype):
    ret = _np_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_np_dtype_dict = {
    trt.int8: np.int8,
    trt.int32: np.int32,
    trt.int64: np.int64,
    trt.float16: np.float16,
    trt.float32: np.float32,
    trt.bool: np.bool_,
    trt.bfloat16: np_bfloat16,
    trt.fp8: np_float8,
}


def trt_dtype_to_np(dtype):
    ret = _trt_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_torch_to_np_dtype_dict = {
    torch.bool: np.bool_,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.int64,
    torch.float16: np.float16,
    torch.bfloat16: np_bfloat16,
    torch.float8_e4m3fn: np_float8,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.complex64: np.complex64,
    torch.complex128: np.complex128,
}


def torch_dtype_to_np(dtype):
    ret = _torch_to_np_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


_trt_to_torch_dtype_dict = {
    trt.float16: torch.float16,
    trt.float32: torch.float32,
    trt.int64: torch.int64,
    trt.int32: torch.int32,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
    trt.fp8: torch.float8_e4m3fn,
}


def trt_dtype_to_torch(dtype):
    ret = _trt_to_torch_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def is_same_dtype(type_a: Union[str, trt.DataType],
                  type_b: Union[str, trt.DataType]) -> bool:
    if isinstance(type_a, str):
        type_a = str_dtype_to_trt(type_a)

    if isinstance(type_b, str):
        type_b = str_dtype_to_trt(type_b)

    return type_a == type_b


_torch_to_trt_dtype_dict = {
    torch.float16: trt.float16,
    torch.float32: trt.float32,
    torch.int64: trt.int64,
    torch.int32: trt.int32,
    torch.int8: trt.int8,
    torch.float8_e4m3fn: trt.fp8,
    torch.qint8: trt.int8,
    torch.bool: trt.bool,
    torch.bfloat16: trt.bfloat16
}


def torch_dtype_to_trt(dtype):
    ret = _torch_to_trt_dtype_dict.get(dtype)
    assert ret is not None, f'Unsupported dtype: {dtype}'
    return ret


def dim_to_trt_axes(dim):
    """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
    if not isinstance(dim, tuple):
        dim = (dim, )

    # create axes bitmask for reduce layer
    axes = 0
    for d in dim:
        axes |= 1 << d

    return axes


def trt_axes_to_dim(axes: int) -> List[int]:
    """Converts tensorrt axes bitmask to dims"""
    dim = []
    for i in range(32):
        if axes & (1 << i):
            dim.append(i)

    return dim


def dim_resolve_negative(dim, ndim):
    if not isinstance(dim, tuple):
        dim = (dim, )
    pos = []
    for d in dim:
        if d < 0:
            d = ndim + d
        pos.append(d)
    return tuple(pos)


# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9


def mpi_comm():
    from mpi4py import MPI
    return MPI.COMM_WORLD


def mpi_rank():
    return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0


def mpi_world_size():
    return mpi_comm().Get_size() if ENABLE_MULTI_DEVICE else 1


def mpi_barrier():
    mpi_comm().Barrier()


def mpi_broadcast(obj, root=0):
    return mpi_comm().bcast(obj, root)


def pad_vocab_size(vocab_size, tp_size):
    return int(math.ceil(vocab_size / tp_size) * tp_size)


def to_dict(obj):
    return copy.deepcopy(obj.__dict__)


def to_json_string(obj):
    if not isinstance(obj, dict):
        obj = to_dict(obj)
    return json.dumps(obj, indent=2, sort_keys=True) + "\n"


def to_json_file(obj, json_file_path):
    with open(json_file_path, "w", encoding="utf-8") as writer:
        writer.write(to_json_string(obj))


def numpy_fp32_to_bf16(src):
    # Numpy doesn't support bfloat16 type
    # Convert float32 to bfloat16 manually and assign with bf16 abstract type
    original_shape = src.shape
    src = src.flatten()
    src = np.ascontiguousarray(src)

    assert src.dtype == np.float32
    dst = np.empty_like(src, dtype=np.uint16)
    for i in range(len(dst)):
        bytes = struct.pack('<f', src[i])
        dst[i] = struct.unpack('<H', struct.pack('BB', bytes[2], bytes[3]))[0]
    return dst.reshape(original_shape).view(np_bfloat16)


_extra_attrs_by_object: Dict[int, Dict[str, Any]] = {}


def get_extra_attr(obj, attr_name):
    if id(obj) not in _extra_attrs_by_object:
        return None
    extra_attrs = _extra_attrs_by_object[id(obj)]
    return extra_attrs.get(attr_name)


def _clean_extra_attrs(obj_id):
    if obj_id in _extra_attrs_by_object:
        del _extra_attrs_by_object[obj_id]


def set_extra_attr(obj, attr_name, value):
    if id(obj) not in _extra_attrs_by_object:
        _extra_attrs_by_object[id(obj)] = {}
        weakref.finalize(obj, _clean_extra_attrs, id(obj))
    _extra_attrs_by_object[id(obj)][attr_name] = value


def has_extra_attr(obj, attr_name):
    if id(obj) not in _extra_attrs_by_object:
        return False
    return attr_name in _extra_attrs_by_object[id(obj)]


def set_obj_attrs(
    obj: torch.Tensor,
    ojb_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on a object.

    This method is used to set attributes on a object. This method
    will not overwrite existing attributes.
    """
    if ojb_attrs is None:
        return
    for key, value in ojb_attrs.items():
        assert not hasattr(
            obj, key), (f"Overwriting existing tensor attribute: {key}")
        setattr(obj, key, value)


def get_init_params(obj, cls=None):
    """
    Get all parameters in object's __init__.
    Use cls's __init__ as filter if cls provided.
    """
    names = None
    if cls is not None:
        names = set(list(inspect.signature(cls.__init__).parameters)[1:])
    return {
        name: getattr(obj, name)
        for name in list(inspect.signature(obj.__class__.__init__).parameters)
        [1:] if names is None or name in names
    }


def release_gc():
    ''' Release memory allocated by PyTorch and Python garbage collector explicitly and immediately.
    This could be used when some states might be kept in memory even after the variables are deleted.
    '''
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


class DictConversion:

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        obj = cls()
        fields = obj.__dataclass_fields__
        for key, value in config.items():
            assert hasattr(obj, key)
            field_cls = fields[key].type
            if (isinstance(field_cls, type)
                    and issubclass(field_cls, DictConversion)
                    and isinstance(value, dict)):
                value = field_cls.from_dict(value)
            setattr(obj, key, value)
        return obj

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_json_file(cls, file):
        with open(file) as f:
            return cls.from_dict(json.load(f))

    def set_defaults(self, **kwargs):
        for key, default in kwargs.items():
            value = getattr(self, key)
            if (value is None
                    or (isinstance(value, (list, dict)) and len(value) == 0)):
                setattr(self, key, default)


class BaseEnumMeta(EnumMeta):

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True
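The abstract `np_bfloat16`/`np_float8` dtypes defined above let tensors round-trip between torch and numpy even though numpy has no native bfloat16. A short sketch, assuming the wheel is installed so that `tensorrt_llm._utils` imports cleanly:

import numpy as np
import torch

from tensorrt_llm._utils import (np_bfloat16, numpy_to_torch, str_dtype_to_np,
                                 torch_to_numpy)

t = torch.ones(4, dtype=torch.bfloat16)
a = torch_to_numpy(t)  # 2-byte void dtype tagged as bfloat16
assert a.dtype == np_bfloat16
assert numpy_to_torch(a).dtype == torch.bfloat16

assert str_dtype_to_np("float16") == np.float16  # plain dtypes map directly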
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .auto_parallel import auto_parallel
from .cluster_info import infer_cluster_config
from .config import AutoParallelConfig

__all__ = [
    'auto_parallel',
    'AutoParallelConfig',
    'infer_cluster_config',
]
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/auto_parallel.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import tensorrt as trt
|
7 |
+
import torch
|
8 |
+
from filelock import FileLock
|
9 |
+
|
10 |
+
from tensorrt_llm.functional import DimRange, Tensor
|
11 |
+
from tensorrt_llm.logger import logger
|
12 |
+
from tensorrt_llm.network import Network, net_guard
|
13 |
+
|
14 |
+
from .config import AutoParallelConfig
|
15 |
+
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
|
16 |
+
from .node_graph import NodeGraph
|
17 |
+
from .parallelization import ParallelConfig, parallelize
|
18 |
+
from .pipeline_graph import PipelineGraph
|
19 |
+
from .simplifier import GraphConfig, Simplifier, StageType
|
20 |
+
from .utils import current_flags
|
21 |
+
|
22 |
+
|
23 |
+
def to_network(graph: PipelineGraph, network: Network):
|
24 |
+
logger.debug("Converting graph to network")
|
25 |
+
trt_network = graph.as_trt()
|
26 |
+
trt_network.name = network.trt_network.name
|
27 |
+
new_network = Network()
|
28 |
+
new_network._init(trt_network)
|
29 |
+
new_network._dtype = network._dtype
|
30 |
+
new_network._plugin_config = network._plugin_config
|
31 |
+
new_network._unfilled_weights = graph._unfilled_weights
|
32 |
+
new_network._auto_parallel_config = graph._auto_parallel_config
|
33 |
+
with net_guard(network):
|
34 |
+
for i in range(trt_network.num_inputs):
|
35 |
+
input = trt_network.get_input(i)
|
36 |
+
tensor = Tensor(is_network_input=False)
|
37 |
+
if input.name in network._inputs:
|
38 |
+
profiles = network._inputs[input.name].profiles
|
39 |
+
elif len(network._inputs) == 0:
|
40 |
+
profiles = []
|
41 |
+
else:
|
42 |
+
shape = input.shape
|
43 |
+
num_profiles = len(list(network._inputs.values())[0].profiles)
|
44 |
+
profile = DimRange(shape, [None] * len(shape))
|
45 |
+
profiles = [profile] * num_profiles
|
46 |
+
tensor.profiles = profiles
|
47 |
+
tensor.trt_tensor = input
|
48 |
+
new_network._inputs[input.name] = tensor
|
49 |
+
return new_network
|
50 |
+
|
51 |
+
|
52 |
+
def find_solution(
|
53 |
+
node_graph: NodeGraph,
|
54 |
+
graph_config: GraphConfig,
|
55 |
+
lmesh: LogicalDeviceMesh,
|
56 |
+
memory_budget: int,
|
57 |
+
flags: list,
|
58 |
+
device: int,
|
59 |
+
dump_path: str,
|
60 |
+
) -> ParallelConfig:
|
61 |
+
torch.cuda.set_device(device)
|
62 |
+
with current_flags(*flags):
|
63 |
+
cost_graph = node_graph.get_cost_graph(lmesh)
|
64 |
+
num_stages = graph_config.num_stages
|
65 |
+
if num_stages == 1:
|
66 |
+
stage_types = [None]
|
67 |
+
elif num_stages == 2:
|
68 |
+
stage_types = [StageType.START, StageType.END]
|
69 |
+
else:
|
70 |
+
stage_types = [StageType.START, StageType.BLOCK, StageType.END]
|
71 |
+
|
72 |
+
best_config, best_solution = None, None
|
73 |
+
for stage_type in stage_types:
|
74 |
+
if stage_type is not None:
|
75 |
+
node_graph.set_slowest_stage(stage_type, graph_config)
|
76 |
+
solution = node_graph.find_solution(
|
77 |
+
cost_graph,
|
78 |
+
memory_budget,
|
79 |
+
)
|
80 |
+
cost = solution.total_cost
|
81 |
+
if best_config is None or cost < best_config.cost:
|
82 |
+
best_config = ParallelConfig()
|
83 |
+
best_config.graph_config = graph_config
|
84 |
+
best_config.lmesh = lmesh
|
85 |
+
best_config.cost = cost
|
86 |
+
best_config.graph_strategy = solution.node_best_strategy
|
87 |
+
best_config.stage_type = stage_type
|
88 |
+
best_solution = solution
|
89 |
+
if dump_path is not None:
|
90 |
+
lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
|
91 |
+
vlz_name = f"{dump_path}/solution."
|
92 |
+
if graph_config.num_micro_batches != 1:
|
93 |
+
vlz_name += f"mbs{graph_config.num_micro_batches}."
|
94 |
+
if graph_config.num_stages != 1:
|
95 |
+
vlz_name += f"stages{graph_config.num_stages}."
|
96 |
+
vlz_name += lmesh.cluster_key
|
97 |
+
with lock:
|
98 |
+
node_graph.visualize_solution(
|
99 |
+
best_solution,
|
100 |
+
vlz_name,
|
101 |
+
ignore_shape_io=True,
|
102 |
+
)
|
103 |
+
return best_config
|
104 |
+
|
105 |
+
|
106 |
+
def infer_builder_flags(network):
|
107 |
+
fp16_enabled = False
|
108 |
+
bf16_enabled = False
|
109 |
+
int8_enabled = False
|
110 |
+
fp8_enabled = False
|
111 |
+
|
112 |
+
def check_dtype(tensor):
|
113 |
+
nonlocal fp16_enabled
|
114 |
+
nonlocal bf16_enabled
|
115 |
+
nonlocal int8_enabled
|
116 |
+
nonlocal fp8_enabled
|
117 |
+
if tensor.dtype == trt.DataType.HALF:
|
118 |
+
fp16_enabled = True
|
119 |
+
elif tensor.dtype == trt.DataType.BF16:
|
120 |
+
bf16_enabled = True
|
121 |
+
elif tensor.dtype == trt.DataType.INT8:
|
122 |
+
int8_enabled = True
|
123 |
+
elif tensor.dtype == trt.DataType.FP8:
|
124 |
+
fp8_enabled = True
|
125 |
+
|
126 |
+
trt_network = network.trt_network
|
127 |
+
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        check_dtype(input)
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            check_dtype(output)

    builder_flags = 0
    if fp16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if bf16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.BF16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if int8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.INT8)
    if fp8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP8)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    return builder_flags


def auto_parallel(network: Network, config: AutoParallelConfig):
    debug_mode = config.debug_mode
    memory_budget = config.get_cluster_info(
    ).memory_budget_per_device * 1024 * 1024 * 1024
    enable_pipeline_parallelism = config.enable_pipeline_parallelism
    if config.world_size < config.gpus_per_node:
        num_hosts = 1
        num_devices_per_host = config.world_size
    else:
        assert config.world_size % config.gpus_per_node == 0
        num_hosts = config.world_size // config.gpus_per_node
        num_devices_per_host = config.gpus_per_node
    parallel_config_cache = config.parallel_config_cache
    dump_path = config.dump_path if debug_mode else None
    fill_weights = config.fill_weights

    if num_hosts == 1 and num_devices_per_host == 1:
        return [network]

    if dump_path is not None:
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)

    builder_flags = config.builder_flags or infer_builder_flags(network)
    flags = [builder_flags, network.strongly_typed]
    with current_flags(*flags):
        simplifier = Simplifier(network, config)
        network_hash = simplifier.get_network_hash()

        best_config = None
        if parallel_config_cache is not None and Path(
                parallel_config_cache).exists():
            parallel_config = ParallelConfig.from_file(parallel_config_cache)
            if (ParallelConfig.VERSION == parallel_config.version
                    and network_hash == parallel_config.network_hash
                    and config == parallel_config.auto_parallel_config):
                logger.info(
                    f"use cache of parallel config from {parallel_config_cache}"
                )
                best_config = parallel_config

        if best_config is None:
            num_devices = num_hosts * num_devices_per_host
            phy_ids = [[
                i + j * num_devices_per_host
                for i in range(num_devices_per_host)
            ] for j in range(num_hosts)]
            phy_mesh = PhysicalDeviceMesh(phy_ids, config)
            if enable_pipeline_parallelism:
                num_micro_batches_list = simplifier.list_all_num_micro_batches()
            else:
                num_micro_batches_list = [1]

            jobs = []
            for num_micro_batches in num_micro_batches_list:
                simplifier.infer_shapes(num_micro_batches)
                if enable_pipeline_parallelism:
                    pipeline_configs = phy_mesh.list_all_pipeline_configs()
                else:
                    pipeline_configs = [(1, num_devices)]
                for num_stages, num_devices_per_stage in pipeline_configs:
                    # TODO: add fallback path that allows num_micro_batches >= num_stages
                    # if no solution satisfies memory budget
                    if num_micro_batches < num_stages:
                        continue
                    simplified_graph, graph_config = simplifier.simplify_graph(
                        phy_mesh,
                        num_stages,
                        num_devices_per_stage,
                    )
                    if simplified_graph is None:
                        continue
                    node_graph = NodeGraph(simplified_graph)
                    node_graph.assign_cost_weights(graph_config)
                    lmeshes = graph_config.stage_phy_meshes[
                        0].get_logical_meshes()
                    for lmesh in lmeshes:
                        jobs.append(
                            (node_graph, graph_config, lmesh, memory_budget *
                             (num_devices / num_devices_per_stage)))

            try:
                with ThreadPoolExecutor() as executor:
                    best_config = sorted(
                        executor.map(
                            lambda x: find_solution(
                                *x,
                                flags,
                                torch.cuda.current_device(),
                                dump_path,
                            ),
                            jobs,
                        ),
                        key=lambda x: x.cost,
                    )[0]
            finally:
                phy_mesh.close()

            if parallel_config_cache is not None:
                best_config.network_hash = network_hash
                best_config.auto_parallel_config = config
                best_config.save(parallel_config_cache)

        new_graphs = parallelize(simplifier, best_config)

        networks = [to_network(new_graph, network) for new_graph in new_graphs]
        if debug_mode and fill_weights:
            networks[0]._fill_weights()

    gc.collect()
    torch.cuda.empty_cache()

    return networks
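A minimal sketch of how this entry point might be driven, assuming a `Network` has already been produced by the usual tensorrt_llm builder flow; `build_network()` below is a hypothetical placeholder and not part of this file:

from tensorrt_llm.auto_parallel.config import AutoParallelConfig
# from tensorrt_llm.auto_parallel.auto_parallel import auto_parallel

# Hypothetical driver: build_network() stands in for whatever code produces
# the tensorrt_llm Network to be parallelized.
config = AutoParallelConfig(
    world_size=8,  # more than one rank, so auto_parallel actually runs
    gpus_per_node=8,
    cluster_key="A100-SXM-80GB",
    enable_pipeline_parallelism=False,
    parallel_config_cache="parallel_config_cache",  # reuse a cached solution if still valid
)
# networks = auto_parallel(build_network(), config)
# With world_size == 1 the original network is returned unchanged in a
# one-element list; otherwise one sharded network is built per graph
# (typically one per rank) produced by parallelize().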
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/cluster_info.py
ADDED
@@ -0,0 +1,556 @@
import copy
import re
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union

import pynvml
import torch
from cuda import cudart

from tensorrt_llm._utils import DictConversion
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn


@dataclass
class MathThroughput(DictConversion):
    int4: int = 0  # Tflops
    int8: int = 0  # Tflops
    fp8: int = 0  # Tflops
    float16: int = 0  # Tflops
    bfloat16: int = 0  # Tflops
    float32: int = 0  # Tflops

    @staticmethod
    def to_tflops(
        ipc_per_sm: "MathThroughput",
        sm_count: int,
        clock_mhz: int,
    ) -> "MathThroughput":
        tflops = MathThroughput()
        for name in ipc_per_sm.__dataclass_fields__:
            setattr(
                tflops, name,
                getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6))
        return tflops


@dataclass
class ClusterInfo(DictConversion):
    inter_node_bw_per_device: int = 25  # GBps
    intra_node_bw_per_device: int = 0  # GBps
    inter_node_latency: int = 10  # us
    intra_node_latency: int = 10  # us
    intra_node_sharp: bool = False
    inter_node_sharp: bool = True

    memory_bw: int = 0  # GBps
    memory_budget_per_device: int = 0  # GB

    math_throughput: MathThroughput = field(default_factory=MathThroughput)

    memory_efficiency: float = 1.0
    math_efficiency: float = 1.0
    communication_efficiency: float = 1.0


_math_throughputs = {
    "A100": MathThroughput(
        int8=624,
        float16=312,
        bfloat16=312,
        float32=156,
    ),
}

_bandwidths = {
    "PCIe-3": 16,
    "PCIe-4": 32,
    "PCIe-5": 64,
}

cluster_infos = {
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
    "A100-SXM-80GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=2039,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-SXM-40GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-80GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1935,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-40GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "H100-SXM":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        intra_node_sharp=True,
        memory_bw=3350,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1979,
            fp8=1979,
            float16=989,
            bfloat16=989,
            float32=495,
        ),
    ),
    "H100-PCIe":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=_bandwidths["PCIe-5"],
        memory_bw=2000,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1513,
            fp8=1513,
            float16=756,
            bfloat16=756,
            float32=378,
        ),
    ),
    "H20":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        memory_bw=4000,
        memory_budget_per_device=96,
        math_throughput=MathThroughput(
            int8=293,
            fp8=293,
            float16=147,
            bfloat16=147,
            float32=74,
        ),
    ),
    # from https://images.nvidia.cn/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
    "V100-PCIe-16GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=900,
        memory_budget_per_device=16,
        math_throughput=MathThroughput(float32=112),
    ),
    "V100-PCIe-32GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=900,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=112),
    ),
    "V100-SXM-16GB":
    ClusterInfo(
        intra_node_bw_per_device=150,
        memory_bw=900,
        memory_budget_per_device=16,
        math_throughput=MathThroughput(float32=125),
    ),
    "V100-SXM-32GB":
    ClusterInfo(
        intra_node_bw_per_device=150,
        memory_bw=900,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=125),
    ),
    "V100S-PCIe":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-3"],
        memory_bw=1134,
        memory_budget_per_device=32,
        math_throughput=MathThroughput(float32=130),
    ),
    # from https://images.nvidia.cn/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
    "A40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=696,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=600,
            int8=300,
            float16=150,
            bfloat16=150,
            float32=75,
        ),
    ),
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf
    "A30":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=933,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=661,
            int8=330,
            float16=165,
            bfloat16=165,
            float32=82,
        ),
    ),
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf
    "A10":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=500,
            int8=250,
            float16=125,
            bfloat16=125,
            float32=62.5,
        ),
    ),
    "A10G":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=280,
            int8=140,
            float16=70,
            bfloat16=70,
            float32=35,
        ),
    ),
    # from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
    "L40S":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=733,
            int8=733,
            fp8=733,
            float16=362,
            bfloat16=362,
            float32=183,
        ),
    ),
    # from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf
    "L40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=724,
            int8=362,
            fp8=362,
            float16=181,
            bfloat16=181,
            float32=90,
        ),
    ),
    "L20":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int8=238,
            fp8=238,
            float16=119,
            bfloat16=119,
            float32=60,
        ),
    ),
    # from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652
    "L4":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=242,
            fp8=242,
            float16=120,
            bfloat16=120,
            float32=60,
        ),
    ),
    "L2":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=193,
            fp8=193,
            float16=97,
            bfloat16=97,
            float32=48,
        ),
    ),
}


def infer_cluster_key() -> str:

    def match(product, name):
        # Use A100 as example, the regex pattern matches for:
        # - NVIDIA A100 80GB
        # - NVIDIA A100-PCIE
        # - NVIDIA A100
        # And does not match A1000 etc.
        return re.match(f".*{product}([ -]|$).*", name) is not None

    def is_sxm():
        return "SXM" in device_name

    def is_80gb():
        return "80GB" in device_name

    def is_32gb():
        return "32GB" in device_name

    device_name = torch.cuda.get_device_name(torch.cuda.current_device())

    if match("A100", device_name):
        if is_sxm():
            if is_80gb():
                return "A100-SXM-80GB"
            else:
                return "A100-SXM-40GB"
        else:
            if is_80gb():
                return "A100-PCIe-80GB"
            else:
                return "A100-PCIe-40GB"
    elif match("A10G", device_name):
        return "A10G"
    elif match("A10", device_name):
        return "A10"
    elif match("A30", device_name):
        return "A30"
    elif match("A40", device_name):
        return "A40"
    elif match("H100", device_name):
        if is_sxm():
            return "H100-SXM"
        else:
            return "H100-PCIe"
    elif match("L40S", device_name):
        return "L40S"
    elif match("L40", device_name):
        return "L40"
    elif match("L4", device_name):
        return "L4"
    elif match("V100S", device_name):
        return "V100S-PCIe"
    elif match("V100", device_name):
        if is_sxm():
            if is_32gb():
                return "V100-SXM-32GB"
            else:
                return "V100-SXM-16GB"
        else:
            if is_32gb():
                return "V100-PCIe-32GB"
            else:
                return "V100-PCIe-16GB"
    return None


def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
    ipc_table = {
        (9, 0):
        MathThroughput(
            int8=16384,
            fp8=16384,
            float16=8192,
            bfloat16=8192,
            float32=4096,
        ),
        (8, 0):
        MathThroughput(
            int4=8192,
            int8=4096,
            float16=2048,
            bfloat16=2048,
            float32=1024,
        ),
        (8, 6):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            bfloat16=1024,
            float32=512,
        ),
        (8, 9):
        MathThroughput(
            int4=2048,
            int8=1024,
            fp8=1024,
            float16=512,
            bfloat16=512,
            float32=256,
        ),
        (7, 0):
        MathThroughput(
            float16=1024,
            float32=128,
        ),
        (7, 5):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            float32=128,
        ),
    }
    return ipc_table.get(compute_cap, MathThroughput())


def nvlink_version(version_enum: int) -> int:
    nvl_version_table = {
        1: 1,
        2: 2,
        3: 2,
        4: 2,
        5: 3,
        6: 3,
        7: 4,
    }
    return nvl_version_table[version_enum]


def nvlink_bandwidth(nvlink_version: int) -> int:
    nvl_bw_table = {
        1: 80,
        2: 150,
        3: 300,
        4: 450,
    }
    return nvl_bw_table[nvlink_version]


def infer_cluster_info() -> ClusterInfo:
    device = torch.cuda.current_device()
    index = device.index if isinstance(device, torch.device) else device
    with PyNVMLContext():
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        logger.info(f"Compute capability: {compute_cap}")
        err, properties = cudart.cudaGetDeviceProperties(index)
        sm_count = properties.multiProcessorCount
        logger.info(f"SM count: {sm_count}")
        sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_SM,
        )
        logger.info(f"SM clock: {sm_clock} MHz")
        math_throughput = MathThroughput.to_tflops(
            ipc_per_sm(compute_cap),
            sm_count,
            sm_clock,
        )
        for name in math_throughput.__dataclass_fields__:
            tflops = getattr(math_throughput, name)
            logger.info(f"{name} TFLOPS: {tflops}")

        mem_info = _device_get_memory_info_fn(handle)
        memory_budget = mem_info.total // (1024**3)
        logger.info(f"Total Memory: {memory_budget} GiB")

        mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_MEM,
        )
        logger.info(f"Memory clock: {mem_clock} MHz")
        if pynvml.__version__ < '11.5.0':
            mem_bus_width = properties.memoryBusWidth
        else:
            mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
        logger.info(f"Memory bus width: {mem_bus_width}")
        memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
        logger.info(f"Memory bandwidth: {memory_bw} GB/s")

        try:
            is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
            logger.info(f"NVLink is active: {is_nvl_active}")
        except pynvml.NVMLError:
            is_nvl_active = False

        intra_node_sharp = False
        if is_nvl_active:
            nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
            nvl_version = nvlink_version(nvl_version_enum)
            logger.info(f"NVLink version: {nvl_version}")
            nvl_bw = nvlink_bandwidth(nvl_version)
            logger.info(f"NVLink bandwidth: {nvl_bw} GB/s")
            intra_node_bw = nvl_bw
            if nvl_version >= 4:
                intra_node_sharp = True
        else:
            if pynvml.__version__ < '11.5.0':
                pcie_gen = pynvml.nvmlDeviceGetCurrPcieLinkGeneration(handle)
                pcie_speed = (2**pcie_gen) * 1000
            else:
                pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
            logger.info(f"PCIe speed: {pcie_speed} Mbps")
            pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
            logger.info(f"PCIe link width: {pcie_link_width}")
            pcie_bw = pcie_speed * pcie_link_width // int(8e3)
            logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
            intra_node_bw = pcie_bw

    cluster_info = ClusterInfo(
        math_throughput=math_throughput,
        memory_bw=memory_bw,
        memory_budget_per_device=memory_budget,
        intra_node_bw_per_device=intra_node_bw,
        intra_node_sharp=intra_node_sharp,
    )
    return cluster_info


def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    cluster_key = infer_cluster_key()
    if cluster_key is not None:
        return dict(cluster_key=cluster_key)
    else:
        try:
            cluster_info = infer_cluster_info()
        except pynvml.NVMLError:
            fallback_cluster_key = "L40"
            cluster_info = copy.copy(cluster_infos[fallback_cluster_key])
            memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
            cluster_info.memory_budget_per_device = memory_budget
            logger.warning(
                f"Failed to infer cluster info for {device_name}, "
                f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
                "This setting makes no effect if you do not use auto parallel.")
        return dict(
            cluster_key=device_name.replace(" ", "-"),
            cluster_info=cluster_info,
        )


if __name__ == "__main__":
    logger.set_level("info")
    infer_cluster_info()
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/config.py
ADDED
@@ -0,0 +1,61 @@
from dataclasses import dataclass, field
from enum import auto
from typing import Dict, List, Optional, Union

from strenum import LowercaseStrEnum

from tensorrt_llm._utils import BaseEnumMeta, DictConversion

from .cluster_info import ClusterInfo, cluster_infos


class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta):
    ALPHA_BETA = auto()
    PROFILE = auto()
    S_CURVE = auto()
    # Zero cost model is for test purpose.
    # Use zero cost model for communication will make solver prefer sharding
    # Use zero cost model for computation will make solver prefer replication
    ZERO = auto()


@dataclass
class AutoParallelConfig(DictConversion):
    # cluster configuration
    world_size: int = 1
    gpus_per_node: int = 8
    cluster_key: str = None
    cluster_info: Optional[ClusterInfo] = None

    # cost model configuration
    sharding_cost_model: str = CostModel.ALPHA_BETA
    comm_cost_model: str = CostModel.ALPHA_BETA

    # strategy configuration
    enable_pipeline_parallelism: bool = False
    enable_shard_unbalanced_shape: bool = False
    enable_shard_dynamic_shape: bool = False
    enable_reduce_scatter: bool = True

    # parallelization configuration
    builder_flags: Optional[int] = None
    debug_mode: bool = False
    infer_shape: bool = True
    validation_mode: bool = False
    same_buffer_io: Dict[str, str] = field(default_factory=dict)
    same_spec_io: Dict[str, str] = field(default_factory=dict)
    sharded_io_allowlist: List[str] = field(default_factory=list)
    fill_weights: bool = False

    # debug configuration
    parallel_config_cache: Optional[str] = None
    profile_cache: Optional[str] = None
    dump_path: Optional[str] = None
    debug_outputs: Union[List[str], str] = field(default_factory=list)

    def get_cluster_info(self) -> ClusterInfo:
        return self.cluster_info or cluster_infos[self.cluster_key]

    @property
    def enabled(self) -> bool:
        return self.world_size > 1
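A minimal sketch of the two ways get_cluster_info resolves hardware information, assuming the modules are importable under tensorrt_llm.auto_parallel as laid out in this upload:

from tensorrt_llm.auto_parallel.cluster_info import ClusterInfo
from tensorrt_llm.auto_parallel.config import AutoParallelConfig

# Resolve from the built-in table via cluster_key ...
config = AutoParallelConfig(world_size=4, cluster_key="H100-SXM")
assert config.enabled  # world_size > 1
print(config.get_cluster_info().memory_budget_per_device)  # 80 (GB)

# ... or pass an explicit ClusterInfo, which takes precedence over cluster_key.
custom = AutoParallelConfig(
    world_size=2,
    cluster_info=ClusterInfo(memory_bw=1000, memory_budget_per_device=32),
)
print(custom.get_cluster_info().memory_budget_per_device)  # 32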
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/device_mesh.py
ADDED
@@ -0,0 +1,612 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
import numpy as np
|
8 |
+
from filelock import FileLock
|
9 |
+
|
10 |
+
from .config import AutoParallelConfig, CostModel
|
11 |
+
from .tensor_parallel.shape_consistency import ShapeConsistencyManager
|
12 |
+
|
13 |
+
|
14 |
+
class ProfileDB(ABC):
|
15 |
+
"""A database that stores profiling results for multiple device mesh
|
16 |
+
shapes."""
|
17 |
+
|
18 |
+
@abstractmethod
|
19 |
+
def query(self, cluster_key, data_key):
|
20 |
+
...
|
21 |
+
|
22 |
+
@abstractmethod
|
23 |
+
def update(self, cluster_key, data_key, mesh_result):
|
24 |
+
...
|
25 |
+
|
26 |
+
def close(self):
|
27 |
+
pass
|
28 |
+
|
29 |
+
|
30 |
+
class MemDB(ProfileDB):
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
self.data = {}
|
34 |
+
|
35 |
+
def query(self, cluster_key, data_key):
|
36 |
+
key = (cluster_key, data_key)
|
37 |
+
mesh_result = self.data.get(key, None)
|
38 |
+
if mesh_result is None:
|
39 |
+
return None
|
40 |
+
else:
|
41 |
+
return mesh_result[0]
|
42 |
+
|
43 |
+
def update(self, cluster_key, data_key, mesh_result):
|
44 |
+
key = (cluster_key, data_key)
|
45 |
+
self.data[key] = mesh_result
|
46 |
+
|
47 |
+
|
48 |
+
class Hdf5DB(ProfileDB):
|
49 |
+
|
50 |
+
def __init__(self, name):
|
51 |
+
self.name = name
|
52 |
+
lock_name = self.name + ".lock"
|
53 |
+
self.lock = FileLock(lock_name, thread_local=False)
|
54 |
+
|
55 |
+
def query(self, cluster_key, data_key):
|
56 |
+
file_name = f"{self.name}.hdf5"
|
57 |
+
key = str((cluster_key, data_key))
|
58 |
+
self.lock.acquire()
|
59 |
+
mesh_result = None
|
60 |
+
with h5py.File(file_name, 'a') as f:
|
61 |
+
if key in f:
|
62 |
+
self.lock.release()
|
63 |
+
mesh_result = f[key]
|
64 |
+
return mesh_result[0]
|
65 |
+
else:
|
66 |
+
return None
|
67 |
+
|
68 |
+
def update(self, cluster_key, data_key, mesh_result):
|
69 |
+
key = str((cluster_key, data_key))
|
70 |
+
file_name = f"{self.name}.hdf5"
|
71 |
+
with h5py.File(file_name, 'a') as f:
|
72 |
+
f[key] = mesh_result
|
73 |
+
|
74 |
+
def close(self):
|
75 |
+
self.lock.release(force=True)
|
76 |
+
|
77 |
+
|
78 |
+
class LogicalDeviceMesh(object):
|
79 |
+
|
80 |
+
def __init__(self,
|
81 |
+
phy_mesh_shape,
|
82 |
+
mesh_shape,
|
83 |
+
phy_ids,
|
84 |
+
config: AutoParallelConfig,
|
85 |
+
alpha,
|
86 |
+
beta,
|
87 |
+
sharp,
|
88 |
+
prof_database=None,
|
89 |
+
shape_consistency_manager=None,
|
90 |
+
host_ips=None):
|
91 |
+
self.phy_mesh_shape = phy_mesh_shape
|
92 |
+
self.mesh_shape = mesh_shape
|
93 |
+
self.phy_ids = phy_ids
|
94 |
+
self.host_ips = host_ips
|
95 |
+
self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
|
96 |
+
[str(i) for i in mesh_shape]))
|
97 |
+
self.prof_min_max_size = [1, 2**34]
|
98 |
+
self.prof_comm_dtypes = [
|
99 |
+
"int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
|
100 |
+
"float32", "float64", "bfloat16"
|
101 |
+
]
|
102 |
+
self.devices_group = {
|
103 |
+
(0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
|
104 |
+
(1, ): [self.phy_ids, self.mesh_shape[1]],
|
105 |
+
(0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
|
106 |
+
}
|
107 |
+
self.prof_database = prof_database
|
108 |
+
self.shape_consistency_manager = shape_consistency_manager
|
109 |
+
self.config = config
|
110 |
+
self.cluster_info = config.get_cluster_info()
|
111 |
+
self.hw_alpha = alpha
|
112 |
+
self.hw_beta = beta
|
113 |
+
self.hw_sharp = sharp
|
114 |
+
self.algo_alpha_beta = self._estimate_algo_alpha_beta()
|
115 |
+
self.comm_op_to_nccl_test_func_name = {
|
116 |
+
'all_reduce': 'all_reduce_perf_mpi',
|
117 |
+
'all_gather': 'all_gather_perf_mpi',
|
118 |
+
'all_to_all': 'alltoall_perf_mpi',
|
119 |
+
'reduce_scatter': 'reduce_scatter_perf_mpi',
|
120 |
+
'split': 'split',
|
121 |
+
}
|
122 |
+
|
123 |
+
@property
|
124 |
+
def size(self) -> int:
|
125 |
+
return self.phy_ids.size
|
126 |
+
|
127 |
+
def _estimate_algo_alpha_beta(self):
|
128 |
+
ret = {}
|
129 |
+
ar_alpha, ar_beta = {}, {}
|
130 |
+
ag_alpha, ag_beta = {}, {}
|
131 |
+
rs_alpha, rs_beta = {}, {}
|
132 |
+
a2a_alpha, a2a_beta = {}, {}
|
133 |
+
phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
|
134 |
+
if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
|
135 |
+
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
|
136 |
+
num_devices = 1
|
137 |
+
for dim in dims:
|
138 |
+
num_devices = self.mesh_shape[dim] * num_devices
|
139 |
+
if num_devices != 1:
|
140 |
+
ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
|
141 |
+
0] else self.hw_alpha[0] * num_devices / 2 / (
|
142 |
+
num_devices - 1)
|
143 |
+
ar_beta[dims] = self.hw_beta[0]
|
144 |
+
ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
145 |
+
num_devices - 1)
|
146 |
+
ag_beta[dims] = self.hw_beta[0]
|
147 |
+
rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
148 |
+
num_devices - 1)
|
149 |
+
rs_beta[dims] = self.hw_beta[0]
|
150 |
+
a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
151 |
+
num_devices - 1)
|
152 |
+
a2a_beta[dims] = self.hw_beta[0]
|
153 |
+
# phy and logical have the same mesh shape if num_hosts > 1 and num_devices_per_host > 1
|
154 |
+
else:
|
155 |
+
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
|
156 |
+
num_devices = 1
|
157 |
+
for dim in dims:
|
158 |
+
num_devices = self.mesh_shape[dim] * num_devices
|
159 |
+
if num_devices != 1:
|
160 |
+
if len(dims) == 1:
|
161 |
+
dim = dims[0]
|
162 |
+
ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
|
163 |
+
dim] else self.hw_alpha[dim] * num_devices / 2 / (
|
164 |
+
num_devices - 1)
|
165 |
+
ar_beta[dims] = self.hw_beta[dim]
|
166 |
+
ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
167 |
+
num_devices - 1)
|
168 |
+
ag_beta[dims] = self.hw_beta[dim]
|
169 |
+
rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
170 |
+
num_devices - 1)
|
171 |
+
rs_beta[dims] = self.hw_beta[dim]
|
172 |
+
a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
173 |
+
num_devices - 1)
|
174 |
+
a2a_beta[dims] = self.hw_beta[dim]
|
175 |
+
elif len(dims) == 2: # two level communication
|
176 |
+
num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
|
177 |
+
inter_node_col_alpha = self.hw_alpha[
|
178 |
+
0] * num_devices_per_host
|
179 |
+
inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
|
180 |
+
0] else inter_node_col_alpha * num_hosts / 2 / (
|
181 |
+
num_hosts - 1)
|
182 |
+
intra_node_ar_alpha = self.hw_alpha[1]
|
183 |
+
intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
|
184 |
+
1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
|
185 |
+
num_devices_per_host - 1)
|
186 |
+
ar_alpha[dims] = min(inter_node_ar_alpha,
|
187 |
+
intra_node_ar_alpha)
|
188 |
+
ar_beta[dims] = max(self.hw_beta)
|
189 |
+
ag_alpha[dims] = min(
|
190 |
+
inter_node_col_alpha * num_hosts / (num_hosts - 1),
|
191 |
+
self.hw_alpha[1] * num_devices_per_host /
|
192 |
+
(num_devices_per_host - 1))
|
193 |
+
ag_beta[dims] = max(self.hw_beta)
|
194 |
+
rs_alpha[dims] = ag_alpha[dims]
|
195 |
+
rs_beta[dims] = ag_beta[dims]
|
196 |
+
a2a_alpha[dims] = min(
|
197 |
+
num_hosts * self.hw_alpha[0] / (num_hosts - 1),
|
198 |
+
self.hw_alpha[1] * num_hosts)
|
199 |
+
a2a_beta[dims] = max(self.hw_beta)
|
200 |
+
else:
|
201 |
+
pass
|
202 |
+
ret['all_to_all'] = [a2a_alpha, a2a_beta]
|
203 |
+
ret['all_reduce'] = [ar_alpha, ar_beta]
|
204 |
+
ret['all_gather'] = [ag_alpha, ag_beta]
|
205 |
+
ret['reduce_scatter'] = [rs_alpha, rs_beta]
|
206 |
+
ret['p2p_cross_device'] = [
|
207 |
+
self.cluster_info.intra_node_bw_per_device,
|
208 |
+
self.cluster_info.intra_node_latency
|
209 |
+
]
|
210 |
+
ret['p2p_cross_host'] = [
|
211 |
+
self.cluster_info.inter_node_bw_per_device,
|
212 |
+
self.cluster_info.inter_node_latency
|
213 |
+
]
|
214 |
+
return ret
|
215 |
+
|
216 |
+
#[ToDo][KDuan] stub functions here
|
217 |
+
def _profile_split(self, min_max_comm_size):
|
218 |
+
comm_size, elapsed_time = [], []
|
219 |
+
size = min_max_comm_size[0]
|
220 |
+
while size <= min_max_comm_size[1]:
|
221 |
+
time = size * 2 / self.cluster_info.memory_bw
|
222 |
+
comm_size.append(size)
|
223 |
+
elapsed_time.append(time)
|
224 |
+
size = size * 2
|
225 |
+
return np.array([comm_size, elapsed_time])
|
226 |
+
|
227 |
+
def _prase_nccl_test_results(self, f_nccl_test_out_log):
|
228 |
+
'''[ToDo][KDuan] There is some dtye that may not been supported by nccl test, using default dtype (float)'''
|
229 |
+
start_parse = False
|
230 |
+
comm_size, elapsed_time = [], []
|
231 |
+
try:
|
232 |
+
with open(f_nccl_test_out_log, 'r') as lines:
|
233 |
+
for line in lines:
|
234 |
+
if start_parse:
|
235 |
+
prof_data = re.split(r"[ ]+", line.strip())
|
236 |
+
if len(prof_data) != 13:
|
237 |
+
continue
|
238 |
+
comm_size.append(float(prof_data[0]))
|
239 |
+
elapsed_time.append(float(prof_data[5]))
|
240 |
+
if 'GB/s' in line and 'us' in line:
|
241 |
+
start_parse = True
|
242 |
+
except Exception:
|
243 |
+
print(f'failed to parse {f_nccl_test_out_log}')
|
244 |
+
return comm_size, elapsed_time
|
245 |
+
|
246 |
+
def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
|
247 |
+
func_name, step, workload_key):
|
248 |
+
|
249 |
+
if func_name == 'split':
|
250 |
+
if 2 == step:
|
251 |
+
return self._profile_split(min_max_comm_size)
|
252 |
+
else:
|
253 |
+
return None
|
254 |
+
workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
|
255 |
+
os.makedirs(workspace_dir, exist_ok=True)
|
256 |
+
outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
|
257 |
+
if 1 == step:
|
258 |
+
num_nodes = len(self.host_ips)
|
259 |
+
num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
|
260 |
+
ntasks_per_node = num_gpus // num_nodes
|
261 |
+
nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
|
262 |
+
device_group[1], func_name, min_max_comm_size[0],
|
263 |
+
min_max_comm_size[1], dtype, 2)
|
264 |
+
sbatch_command = '#!/bin/bash\n'
|
265 |
+
sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
|
266 |
+
sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
|
267 |
+
sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
|
268 |
+
sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
|
269 |
+
sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
|
270 |
+
sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
|
271 |
+
ntasks_per_node)
|
272 |
+
sbatch_command += '#SBATCH --exclusive\n'
|
273 |
+
sbatch_command += '#SBATCH --mem=0\n'
|
274 |
+
sbatch_command += '#SBATCH --network=sharp\n'
|
275 |
+
sbatch_command += '#SBATCH --mail-type=FAIL\n'
|
276 |
+
srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
|
277 |
+
num_nodes, ntasks_per_node, outfile, errfile,
|
278 |
+
self.config['container'])
|
279 |
+
command = sbatch_command + srun_command + nccl_test_command
|
280 |
+
with open(workspace_dir + '/workload.sub', 'w') as f:
|
281 |
+
f.write(command)
|
282 |
+
with open('./preprofiling_step1.sh', 'a') as f:
|
283 |
+
f.write(f'sbatch {workspace_dir}/workload.sub\n')
|
284 |
+
return None
|
285 |
+
|
286 |
+
else:
|
287 |
+
comm_size, elapsed_time = self._prase_nccl_test_results(outfile)
|
288 |
+
if len(comm_size) < 2:
|
289 |
+
assert 0, 'the profiling for {} was failed at step1, please try again'.format(
|
290 |
+
workload_key)
|
291 |
+
else:
|
292 |
+
print(workload_key, comm_size, elapsed_time)
|
293 |
+
return np.array([comm_size, elapsed_time])
|
294 |
+
|
295 |
+
def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
|
296 |
+
results = {}
|
297 |
+
func_name = self.comm_op_to_nccl_test_func_name[comm_op]
|
298 |
+
for dtype in self.prof_comm_dtypes:
|
299 |
+
size_time = self._profile_with_nccl_test(
|
300 |
+
self.prof_min_max_size, dtype, device_group, func_name, step,
|
301 |
+
data_key + f'_dtype{dtype}')
|
302 |
+
results[dtype] = size_time
|
303 |
+
return results
|
304 |
+
|
305 |
+
def profile_all_comms_perf(self, step):
|
306 |
+
if self.mesh_shape == (1, 1):
|
307 |
+
return None
|
308 |
+
mesh_results = self.prof_database.query(self.cluster_key,
|
309 |
+
self.mesh_shape)
|
310 |
+
if mesh_results:
|
311 |
+
return mesh_results
|
312 |
+
|
313 |
+
mesh_results = {}
|
314 |
+
data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
|
315 |
+
for comm_op in [
|
316 |
+
'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
|
317 |
+
'split'
|
318 |
+
]:
|
319 |
+
comm_perf = {}
|
320 |
+
for dim, device_group in self.devices_group.items():
|
321 |
+
# don't need to profile for mesh dim == 1
|
322 |
+
if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
|
323 |
+
continue
|
324 |
+
|
325 |
+
comm_perf[dim] = self._profile_single_comm_perf(
|
326 |
+
device_group, comm_op, step, data_key +
|
327 |
+
'_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
|
328 |
+
mesh_results[comm_op] = comm_perf
|
329 |
+
if 2 == step:
|
330 |
+
self.prof_database.update(self.cluster_key, self.mesh_shape,
|
331 |
+
mesh_results)
|
332 |
+
|
333 |
+
return mesh_results
|
334 |
+
|
335 |
+
def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
|
336 |
+
assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
|
337 |
+
'the comm_size: {} is not in the profile range: [{}{}]'\
|
338 |
+
.format(realsize, size_time_array[0][0], size_time_array[0][-1])
|
339 |
+
return np.interp(realsize, size_time_array[0], size_time_array[1])
|
340 |
+
|
341 |
+
def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
|
342 |
+
elapsed_time = 0.0
|
343 |
+
if 'split' == comm_op:
|
344 |
+
elapsed_time = size_in_bytes * 2 / (
|
345 |
+
self.cluster_info.memory_bw *
|
346 |
+
self.cluster_info.memory_efficiency) * 1e-3
|
347 |
+
else:
|
348 |
+
dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
|
349 |
+
alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
|
350 |
+
elapsed_time = (size_in_bytes /
|
351 |
+
(alpha * self.cluster_info.communication_efficiency)
|
352 |
+
* 1e-3) + beta
|
353 |
+
return elapsed_time
|
354 |
+
|
355 |
+
def _input_size_to_comm_size(self, comm_op, dims, input_size):
|
356 |
+
ret = input_size
|
357 |
+
if 'all_gather' == comm_op:
|
358 |
+
for dim in dims:
|
359 |
+
ret = ret * self.mesh_shape[dim]
|
360 |
+
return ret
|
361 |
+
|
362 |
+
def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
|
363 |
+
|
364 |
+
size = self._input_size_to_comm_size(comm_op, dim, input_size)
|
365 |
+
if self.config.comm_cost_model == CostModel.S_CURVE:
|
366 |
+
mesh_perf = self.prof_database.query(self.cluster_key,
|
367 |
+
self.mesh_shape)
|
368 |
+
assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
|
369 |
+
self.mesh_shape)
|
370 |
+
comm_op_perf = mesh_perf.get(comm_op, None)
|
371 |
+
assert comm_op_perf is not None, '{} is not profiled'.format(
|
372 |
+
comm_op)
|
373 |
+
elapsed_time = self._model_comm_cost_from_s_curve(
|
374 |
+
comm_op_perf[tuple(dim)][dtype], size)
|
375 |
+
return elapsed_time
|
376 |
+
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
|
377 |
+
elapsed_time = self._model_comm_cost_from_alpha_beta(
|
378 |
+
comm_op, tuple(dim), size)
|
379 |
+
elif self.config.comm_cost_model == CostModel.PROFILE:
|
380 |
+
assert False, 'Unsupported profile based communication cost model now'
|
381 |
+
elif self.config.comm_cost_model == CostModel.ZERO:
|
382 |
+
elapsed_time = 0.0
|
383 |
+
|
384 |
+
return elapsed_time # us
|
385 |
+
|
386 |
+
|
387 |
+
class PhysicalDeviceMesh(object):
|
388 |
+
|
389 |
+
def __init__(self,
|
390 |
+
phy_devices_id,
|
391 |
+
config: AutoParallelConfig,
|
392 |
+
prof_database=None,
|
393 |
+
shape_consistency_manager=None,
|
394 |
+
host_ips=None):
|
395 |
+
self.phy_devices_id = np.array(phy_devices_id)
|
396 |
+
self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
|
397 |
+
self.host_ips = host_ips
|
398 |
+
if host_ips is None:
|
399 |
+
self.host_ips = [''] * self.num_hosts
|
400 |
+
self.config = config
|
401 |
+
self.cluster_info = config.get_cluster_info()
|
402 |
+
self.prof_database: ProfileDB = prof_database
|
403 |
+
self.shape_consistency_manager = shape_consistency_manager
|
404 |
+
if self.config.comm_cost_model not in CostModel:
|
405 |
+
raise ValueError(
|
406 |
+
f'unsupported communication cost model: {self.config.comm_cost_model}'
|
407 |
+
)
|
408 |
+
if self.config.sharding_cost_model not in CostModel:
|
409 |
+
raise ValueError(
|
410 |
+
f'unsupported sharding cost model: {self.config.sharding_cost_model}'
|
411 |
+
)
|
412 |
+
if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
|
413 |
+
if self.prof_database is None:
|
414 |
+
profile_cache = config.profile_cache
|
415 |
+
if profile_cache is None:
|
416 |
+
self.prof_database = MemDB()
|
417 |
+
else:
|
418 |
+
self.prof_database = Hdf5DB(profile_cache)
|
419 |
+
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
|
420 |
+
assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
|
421 |
+
assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
|
422 |
+
if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
|
423 |
+
assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'
|
424 |
+
|
425 |
+
if not shape_consistency_manager:
|
426 |
+
self.shape_consistency_manager = ShapeConsistencyManager()
|
427 |
+
|
428 |
+
@property
|
429 |
+
def size(self) -> int:
|
430 |
+
return self.phy_devices_id.size
|
431 |
+
|
432 |
+
def close(self):
|
433 |
+
if self.prof_database is not None:
|
434 |
+
self.prof_database.close()
|
435 |
+
|
436 |
+
def split_pipeline_meshes(
|
437 |
+
self, num_stages,
|
438 |
+
num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
|
439 |
+
sub_meshes = []
|
440 |
+
if num_devices_per_stage <= self.num_devices_per_host:
|
441 |
+
assert self.num_devices_per_host % num_devices_per_stage == 0, \
|
442 |
+
"num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
|
443 |
+
.format(self.num_devices_per_host, num_devices_per_stage)
|
444 |
+
num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
|
445 |
+
num_clusters = self.num_hosts * num_clusters_per_host
|
446 |
+
assert num_stages % num_clusters == 0, \
|
447 |
+
"num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
|
448 |
+
for mesh_id in range(num_stages):
|
449 |
+
cluster_id = mesh_id % num_clusters
|
450 |
+
cluster_col = cluster_id % num_clusters_per_host
|
451 |
+
cluster_row = cluster_id // num_clusters_per_host
|
452 |
+
sub_devices_id = [
|
453 |
+
self.phy_devices_id[cluster_row][cluster_col *
|
454 |
+
num_devices_per_stage:(
|
455 |
+
(cluster_col + 1) *
|
456 |
+
num_devices_per_stage)]
|
457 |
+
]
|
458 |
+
sub_meshes.append(
|
459 |
+
PhysicalDeviceMesh(sub_devices_id, self.config,
|
460 |
+
self.prof_database,
|
461 |
+
self.shape_consistency_manager,
|
462 |
+
[self.host_ips[cluster_row]]))
|
463 |
+
else:
|
464 |
+
assert num_devices_per_stage % self.num_devices_per_host == 0, \
|
465 |
+
"num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
|
466 |
+
.format(num_devices_per_stage, self.num_devices_per_host)
|
467 |
+
num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
|
468 |
+
assert self.num_hosts % num_host_per_cluster == 0, \
|
469 |
+
"num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
|
470 |
+
num_clusters = self.num_hosts // num_host_per_cluster
|
471 |
+
for mesh_id in range(num_stages):
|
472 |
+
cluster_id = mesh_id % num_clusters
|
473 |
+
cluster_row = cluster_id * num_host_per_cluster
|
474 |
+
sub_devices_id = self.phy_devices_id[cluster_row:(
|
475 |
+
cluster_row + num_host_per_cluster)]
|
476 |
+
host_ips = self.host_ips[cluster_row:(cluster_row +
|
477 |
+
num_host_per_cluster)]
|
478 |
+
sub_meshes.append(
|
479 |
+
PhysicalDeviceMesh(sub_devices_id, self.config,
|
480 |
+
self.prof_database,
|
481 |
+
self.shape_consistency_manager,
|
482 |
+
host_ips))
|
483 |
+
return sub_meshes
|
484 |
+
|
485 |
+
def _profile_logical_meshes(self, logical_meshes, step):
|
486 |
+
for lmesh in logical_meshes:
|
487 |
+
lmesh.profile_all_comms_perf(step)
|
488 |
+
|
489 |
+
def as_logical_mesh(self) -> LogicalDeviceMesh:
|
490 |
+
alpha = [
|
491 |
+
self.cluster_info.inter_node_bw_per_device,
|
492 |
+
self.cluster_info.intra_node_bw_per_device
|
493 |
+
]
|
494 |
+
beta = [
|
495 |
+
self.cluster_info.inter_node_latency,
|
496 |
+
self.cluster_info.intra_node_latency
|
497 |
+
]
|
498 |
+
sharp = [
|
499 |
+
self.cluster_info.inter_node_sharp,
|
500 |
+
self.cluster_info.intra_node_sharp
|
501 |
+
]
|
502 |
+
return LogicalDeviceMesh(
|
503 |
+
self.phy_devices_id.shape,
|
504 |
+
self.phy_devices_id.shape,
|
505 |
+
self.phy_devices_id,
|
506 |
+
self.config,
|
507 |
+
alpha,
|
508 |
+
beta,
|
509 |
+
sharp,
|
510 |
+
self.prof_database,
|
511 |
+
self.shape_consistency_manager,
|
512 |
+
self.host_ips,
|
513 |
+
)
|
514 |
+
|
515 |
+
def get_logical_meshes(self):
|
516 |
+
logical_meshes = []
|
517 |
+
# (1, 2) -> (1, 2)
|
518 |
+
# (1, 4) -> (2, 2)
|
519 |
+
# (1, 8) -> (2, 4)
|
520 |
+
# (1, 16) -> (2, 8), (4, 4)
|
521 |
+
# (1, 32) -> (2, 16), (4, 8)
|
522 |
+
# (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8)
|
523 |
+
# (1, 64) -> (2, 32), (4, 16), (8, 8)
|
524 |
+
# we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2)
|
525 |
+
# we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1)
|
526 |
+
if self.num_hosts == 1:
|
527 |
+
alpha = [self.cluster_info.intra_node_bw_per_device]
|
528 |
+
beta = [self.cluster_info.intra_node_latency]
|
529 |
+
sharp = [self.cluster_info.intra_node_sharp]
|
530 |
+
for i in range(2, self.num_devices_per_host):
|
531 |
+
if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host:
|
532 |
+
lmesh_shape = (i, self.num_devices_per_host // i)
|
533 |
+
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
|
534 |
+
logical_meshes.append(
|
535 |
+
LogicalDeviceMesh(self.phy_devices_id.shape,
|
536 |
+
lmesh_shape, lmesh_phy_ids,
|
537 |
+
self.config, alpha, beta, sharp,
|
538 |
+
self.prof_database,
|
539 |
+
self.shape_consistency_manager,
|
540 |
+
self.host_ips))
|
541 |
+
# (8, 1) -> (2, 4)
|
542 |
+
# (16, 1) -> (2, 8), (4, 4)
|
543 |
+
elif self.num_devices_per_host == 1:
|
544 |
+
alpha = [self.cluster_info.inter_node_bw_per_device]
|
545 |
+
beta = [self.cluster_info.inter_node_latency]
|
546 |
+
sharp = [self.cluster_info.inter_node_sharp]
|
547 |
+
for i in range(2, self.num_hosts):
|
548 |
+
if self.num_hosts % i == 0 and i * i <= self.num_hosts:
|
549 |
+
lmesh_shape = (i, self.num_hosts // i)
|
550 |
+
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
|
551 |
+
logical_meshes.append(
|
552 |
+
LogicalDeviceMesh(self.phy_devices_id.shape,
|
553 |
+
lmesh_phy_ids, self.config, alpha,
|
554 |
+
beta, sharp, self.prof_database,
|
555 |
+
self.shape_consistency_manager,
|
556 |
+
self.host_ips))
|
557 |
+
# (2, 1) -> (2, 1)
|
558 |
+
# (2, 8) -> (2, 8)
|
559 |
+
# (1, 2) -> (1, 2)
|
560 |
+
# (1, 3) -> (1, 3)
|
561 |
+
# (1, 5) -> (1, 5)
|
562 |
+
if 0 == len(logical_meshes):
|
563 |
+
logical_meshes.append(self.as_logical_mesh())
|
564 |
+
return logical_meshes
|
565 |
+
|
566 |
+
'''
|
567 |
+
we assume we can evenly split the pipeline and deviceMesh
|
568 |
+
'''
|
569 |
+
|
570 |
+
def _list_all_sub_meshes(self):
|
571 |
+
sub_meshes = []
|
572 |
+
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
|
573 |
+
if self.num_devices_per_host % num_devices_per_stage == 0:
|
574 |
+
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
|
575 |
+
sub_meshes.append(
|
576 |
+
self.split_pipeline_meshes(num_stages,
|
577 |
+
num_devices_per_stage)[0])
|
578 |
+
for num_hosts_per_stage in range(2, self.num_hosts + 1):
|
579 |
+
if self.num_hosts % num_hosts_per_stage == 0:
|
580 |
+
num_stages = self.num_hosts // num_hosts_per_stage
|
581 |
+
sub_meshes.append(
|
582 |
+
self.split_pipeline_meshes(
|
583 |
+
num_stages,
|
584 |
+
num_hosts_per_stage * self.num_devices_per_host)[0])
|
585 |
+
return sub_meshes
|
586 |
+
|
587 |
+
def list_all_pipeline_configs(self):
|
588 |
+
configs = []
|
589 |
+
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
|
590 |
+
if self.num_devices_per_host % num_devices_per_stage == 0:
|
591 |
+
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
|
592 |
+
configs.append((num_stages, num_devices_per_stage))
|
593 |
+
for num_hosts_per_stage in range(2, self.num_hosts + 1):
|
594 |
+
if self.num_hosts % num_hosts_per_stage == 0:
|
595 |
+
num_stages = self.num_hosts // num_hosts_per_stage
|
596 |
+
configs.append(
|
597 |
+
(num_stages,
|
598 |
+
num_hosts_per_stage * self.num_devices_per_host))
|
599 |
+
return configs
|
600 |
+
|
601 |
+
def profile_s_curve(self, step):
|
602 |
+
sub_phy_device_meshes = self._list_all_sub_meshes()
|
603 |
+
for phy_mesh in sub_phy_device_meshes:
|
604 |
+
lmeshes = phy_mesh.get_logical_meshes()
|
605 |
+
self._profile_logical_meshes(lmeshes, step)
|
606 |
+
if 2 == step:
|
607 |
+
self.save_profile_database()
|
608 |
+
|
609 |
+
def profile_alpha_beta(self):
|
610 |
+
alpha = [250, 25]
|
611 |
+
beta = [100, 100]
|
612 |
+
return alpha, beta
|
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/node_graph.py
ADDED
@@ -0,0 +1,347 @@
from typing import List

import pandas as pd
import tensorrt as trt

from .pipeline_graph import PipelineGraph
from .runtime_profiling import RuntimeProfiler
from .simplifier import GraphConfig, StageType
from .solver import CostGraph, Solver
from .tensor_parallel.activation_node import Activation
from .tensor_parallel.assertion_node import Assertion
from .tensor_parallel.cast_node import Cast
from .tensor_parallel.concatenation_node import Concatenation
from .tensor_parallel.constant_node import Constant
from .tensor_parallel.elementwise_node import ElementWise
from .tensor_parallel.fill_node import Fill
from .tensor_parallel.gather_node import Gather
from .tensor_parallel.identity_node import Identity
from .tensor_parallel.input_node import InputNode
from .tensor_parallel.matmul_node import MatrixMultiply
from .tensor_parallel.node import Node
from .tensor_parallel.normalization_node import Normalization
from .tensor_parallel.output_node import OuputNode
from .tensor_parallel.p2p_node import P2PNode, P2PType
from .tensor_parallel.plugin_node import PluginNode
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
                                                              RMSnormPlugin)
from .tensor_parallel.reduce_node import Reduce
from .tensor_parallel.select_node import Select
from .tensor_parallel.shape_node import Shape
from .tensor_parallel.shuffle_node import Shuffle
from .tensor_parallel.slice_node import Slice
from .tensor_parallel.softmax_node import SoftMax
from .tensor_parallel.unary_node import Unary

LAYER_TYPE_2_NODE_TYPE = {
    trt.LayerType.ACTIVATION: Activation,
    trt.LayerType.ASSERTION: Assertion,
    trt.LayerType.CAST: Cast,
    trt.LayerType.CONCATENATION: Concatenation,
    trt.LayerType.CONSTANT: Constant,
    trt.LayerType.ELEMENTWISE: ElementWise,
    trt.LayerType.FILL: Fill,
    trt.LayerType.GATHER: Gather,
    trt.LayerType.IDENTITY: Identity,
    trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
    trt.LayerType.NORMALIZATION: Normalization,
    trt.LayerType.PLUGIN_V2: PluginNode,
    trt.LayerType.REDUCE: Reduce,
    trt.LayerType.SELECT: Select,
    trt.LayerType.SHAPE: Shape,
    trt.LayerType.SHUFFLE: Shuffle,
    trt.LayerType.SLICE: Slice,
    trt.LayerType.SOFTMAX: SoftMax,
    trt.LayerType.UNARY: Unary,
}
# TODO: BertAttention/All Quant plugins
PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
    'GPTAttention': GPTAttentionPlugin,
    'Gemm': GemmPlugin,
    'Layernorm': LayernormPlugin,
    'Rmsnorm': RMSnormPlugin,
    'Lookup': LookupPlugin,
    'Identity': IdentityPlugin,
}


class NodeGraph:

    def __init__(self, graph: PipelineGraph):
        self._nodes = {}

        # construct nodes
        for input in graph.inputs:
            self._nodes[input.name] = InputNode(input)
        for layer in graph.layers:
            layer.to_base_class()
            if "p2p_type" in layer.attrs:
                self._nodes[layer.name] = P2PNode(layer)
            elif layer.type == trt.LayerType.PLUGIN_V2:
                layer.to_subclass()
                plugin_type = layer.as_trt().plugin.plugin_type
                layer.to_base_class()
                if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE:
                    node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer)
                else:
                    node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
                self._nodes[layer.name] = node
            else:
                node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
                self._nodes[layer.name] = node
        for output in graph.outputs:
            self._nodes[output.name] = OuputNode(output)
        for node in self.nodes:
            node.post_init(self)
            node.node_runtime_profiler = RuntimeProfiler()

    def get_node(self, name):
        return self._nodes[name]

    @property
    def nodes(self) -> List[Node]:
        return [*self._nodes.values()]

    def assign_cost_weights(self, graph_config: GraphConfig):
        layer_mapping = graph_config.graph_mapping.layer_mapping
        for layer_name in layer_mapping.values():
            node = self.get_node(layer_name)
            node.sharding_weight += 1
            node.resharding_weight += 1
        same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping
        for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
            node = self.get_node(layer_name)
            same_spec_node = self.get_node(same_spec_layer_name)
            same_spec_node.sharding_weight = node.sharding_weight
            same_spec_node.resharding_weight = node.resharding_weight

    def set_slowest_stage(self, stage_type: StageType,
                          graph_config: GraphConfig):
        num_micro_batches = graph_config.num_micro_batches
        block_per_stage = graph_config.num_blocks // graph_config.num_stages
        block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
        for node in self.nodes:
            node.pipeline_weight = 0
            node.cost_level = -1
            if node.stage_type == StageType.START:
                if stage_type == StageType.START:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.START and node.in_start_block:
                node.pipeline_weight = block_pipeline_weight
            if node.stage_type == StageType.END:
                if stage_type == StageType.END:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.END and node.in_end_block:
                node.pipeline_weight = block_pipeline_weight
            if isinstance(node, P2PNode):
                if (graph_config.has_cross_host
                        and node.p2p_type == P2PType.CROSS_HOST) or (
                            not graph_config.has_cross_host
                            and node.p2p_type == P2PType.CROSS_DEVICE):
                    if stage_type == StageType.BLOCK:
                        node.pipeline_weight += num_micro_batches - 1
                        node.cost_level = 1
                    else:
                        node.cost_level = 0
                elif (graph_config.has_cross_device
                      and node.p2p_type == P2PType.CROSS_DEVICE) or (
                          not graph_config.has_cross_device
                          and node.p2p_type == P2PType.CROSS_HOST):
                    node.pipeline_weight += num_micro_batches - 1
            if stage_type == StageType.BLOCK and node.in_slowest_block:
                node.pipeline_weight = block_pipeline_weight

    def get_cost_graph(self, lmesh):
        leaf_strategies = []
        for node in self.nodes:
            if node.is_replicated:
                node.set_strategy(None, lmesh)
            else:
                node.collect_strategies(lmesh)
        for node in self.nodes:
            strategies_vector = node.update_resharding_cost()
            if len(strategies_vector) != 0:
                leaf_strategies.append(strategies_vector)
        cost_graph = CostGraph(leaf_strategies)
        return cost_graph

    def find_solution(self, cost_graph, memory_budget):
        solver = Solver(cost_graph, memory_budget=memory_budget)
        solution = solver.find_solution()[1]

        graph_strategy = solution.node_best_strategy
        for node_name, strategy in graph_strategy.items():
            node = self._nodes[node_name]
            for idx, pre_node in enumerate(node.predecessor_nodes):
                if pre_node is None:
                    continue
                if pre_node.node_name not in strategy.best_resharding_cost:
                    continue
                strategy.best_resharding_cost[
                    idx] = strategy.best_resharding_cost[pre_node.node_name]
                strategy.node_names[idx] = pre_node.node_name
            for key in list(strategy.best_resharding_cost.keys()):
                if isinstance(key, str):
                    del strategy.best_resharding_cost[key]

        return solution

    def visualize(self, name='pp_graph'):
        with open(name + '.dot', 'w') as f:
            f.write("digraph {\n")
            '''
            f.write("  // Value Nodes\n")
            for name, tensor in self._tensors.items():
                f.write("    \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape))
            '''
            f.write("  // Operation Nodes\n")
            for name, node in self._nodes.items():
                fillcolor = 'white'
                if 'MATRIX_MULTIPLY' in name:
                    fillcolor = 'green'
                label = name
                if len(node.outputs) > 0:
                    label = name + '\\n' + str(node.outputs[0].shape)
                f.write(
                    "    \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
                    .format(name, fillcolor, label))
            f.write("  // Edges\n")
            for name, node in self._nodes.items():
                for successor_node in node.successor_nodes:
                    if successor_node:
                        f.write("    \"{}\" ->\"{}\";\n".format(
                            name, successor_node.node_name))
            f.write("  }\n")

    def visualize_solution(self,
                           solution,
                           fname='pp_graph_solution',
                           ignore_shape_io=True):
        with open(fname + '.dot', 'w') as f:
            names, costs, block_ids = [], [], []
            f.write("digraph {\n")
            f.write("  // Operation Nodes\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
                    continue
                cost = 0.0
                fillcolor = 'white'
                if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
                    fillcolor = 'orange'
                elif '_same_spec' in name:
                    fillcolor = 'gray'
                elif 'p2p_block' in name:
                    fillcolor = 'blue'
                elif 'PLUGIN' in name:
                    fillcolor = 'yellow'

                shape = 'box'
                if 'output_node' == node.node_type or 'input_node' == node.node_type:
                    shape = 'ellipse'
                    fillcolor = 'green'

                label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
                if len(node.inputs) > 0:
                    for idx, input in enumerate(node.inputs):
                        if not input:
                            continue
                        label = label + f'\\ninput{idx}_' + str(
                            input.shape) + f'_{input.dtype_str_size[0]}_'
                        if node.node_name in solution.node_best_strategy:
                            best_strategy = solution.node_best_strategy[
                                node.node_name]
                            shard_seq = str(
                                best_strategy.sharding_specs[f'input{idx}'].
                                sharding_sequence)
                            label = label + shard_seq
                            if idx not in best_strategy.best_resharding_cost:
                                continue
                            rcosts = best_strategy.best_resharding_cost[idx][0]
                            comm_action_sequence, resharding_cost = rcosts[
                                1], rcosts[2]
                            if len(comm_action_sequence) > 0:
                                label = label + '|'
                                for commspec in comm_action_sequence:
                                    comm = [
                                        commspec.comm_pattern,
                                        commspec.gather_dim,
                                        commspec.shard_dim,
                                        commspec.logical_process_axis
                                    ]
                                    label = label + '->' + str(comm)
                            if resharding_cost > 0:
                                label = label + '_rcost{:.2}'.format(
                                    resharding_cost)
                                cost = cost + resharding_cost
                if len(node.outputs) > 0:
                    best_strategy = None
                    for idx, output in enumerate(node.outputs):
                        label = label + f'\\noutput{idx}_' + str(
                            output.shape) + f'_{output.dtype_str_size[0]}'
                        if node.node_name in solution.node_best_strategy:
                            best_strategy = solution.node_best_strategy[
                                node.node_name]
                            shard_seq = str(
                                best_strategy.sharding_specs[f'output{idx}'].
                                sharding_sequence)
                            comm = None
                            if f'output{idx}' in best_strategy.communication_actions:
                                commspec = best_strategy.communication_actions[
                                    f'output{idx}']
                                comm = [
                                    commspec.comm_pattern, commspec.gather_dim,
                                    commspec.shard_dim,
                                    commspec.logical_process_axis
                                ]
                            label = label + '_' + shard_seq
                            if comm:
                                label = label + f' | {comm}'
                    if best_strategy:
                        cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
                        label = label + '| scost{:.2}'.format(
                            best_strategy.sharding_cost)
                        if best_strategy.communication_cost > 0:
                            label = label + ' | ccost{:.2}'.format(
                                best_strategy.communication_cost)
                names.append(name)
                costs.append(cost)
                block_ids.append([
                    node.building_block_id, node.cost_level,
                    node.sharding_weight + node.pipeline_weight,
                    node.same_spec_id
                ])
                f.write(
                    "    \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
                    .format(name, fillcolor, label, shape))
            f.write("  // Edges\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
                    continue
                for successor_node in node.successor_nodes:
                    if successor_node:
                        if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io:
                            continue
                        f.write("    \"{}\" ->\"{}\";\n".format(
                            name, successor_node.node_name))
            f.write("  }\n")
            df = pd.DataFrame.from_dict({
                'node':
                names,
                'cost':
                costs,
                'block_id': [block[0] for block in block_ids],
                'cost_level': [block[1] for block in block_ids],
                'sharding_weight': [block[2] for block in block_ids],
                'same_spec_id': [block[3] for block in block_ids]
            })
            df['weight_cost'] = df['sharding_weight'] * df['cost']
            df.to_csv(fname + '.csv')
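The listing above defines the node construction and solver entry points of the auto-parallel pass. As orientation, the following editor-added sketch (not part of the uploaded file) shows how these pieces are typically chained together; `trt_network`, `lmesh`, and the memory budget are placeholder assumptions, and the import paths are inferred from the directory layout of this upload.

    # Illustrative only: assumed import paths and placeholder inputs.
    from tensorrt_llm.auto_parallel.pipeline_graph import PipelineGraph
    from tensorrt_llm.auto_parallel.node_graph import NodeGraph

    graph = PipelineGraph.from_trt(trt_network)      # trt_network: an existing trt.INetworkDefinition
    graph.assign_shapes()                            # take shapes straight from the TRT network
    node_graph = NodeGraph(graph)                    # wraps inputs, layers, and outputs as Nodes
    cost_graph = node_graph.get_cost_graph(lmesh)    # lmesh: logical device mesh (placeholder)
    solution = node_graph.find_solution(cost_graph, memory_budget=8 * 2**30)
    node_graph.visualize_solution(solution, fname='pp_graph_solution')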
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py
ADDED
The diff for this file is too large to render.
See raw diff
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/pipeline_graph.py
ADDED
@@ -0,0 +1,1035 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, get_plugin_info, set_plugin_info
from tensorrt_llm.plugin.plugin import PluginConfig
from tensorrt_llm.runtime.session import Session

from .utils import (current_flags, get_builder_flags, get_sorted_layer_ids,
                    get_strongly_typed, get_trt_network, set_trt_network,
                    to_base_class_layer, to_subclass_layer)


class Tensor:

    def __init__(self, graph: "PipelineGraph"):
        self._graph = graph
        self._trt = None
        self._shape = None
        self._max_shape = None
        self._value = None
        self.producer: Layer = None
        self.output_index = None
        self.consumers = []
        self.graph_input_index = -1
        self.graph_output_index = -1
        self.attrs = {}

    @staticmethod
    def from_trt(graph: "PipelineGraph", trt_tensor: trt.ITensor):
        tensor = Tensor(graph)
        tensor._trt = trt_tensor
        return tensor

    def as_trt(self) -> trt.ITensor:
        return self._trt

    def copy(self) -> "Tensor":
        tensor = Tensor(self._graph)
        tensor._trt = self._trt
        tensor._shape = self._shape
        tensor._max_shape = self._max_shape
        tensor._value = self._value
        tensor.producer = self.producer
        tensor.output_index = self.output_index
        tensor.consumers = [*self.consumers]
        tensor.graph_input_index = self.graph_input_index
        tensor.graph_output_index = self.graph_output_index
        tensor.attrs = self.attrs.copy()
        return tensor

    @property
    def graph(self) -> "PipelineGraph":
        return self._graph

    @property
    def name(self) -> str:
        return self._trt.name

    @name.setter
    def name(self, name: str):
        old_name = self._trt.name
        if name != old_name:
            self._trt.name = name
            self.graph._tensors[name] = self
            del self.graph._tensors[old_name]
            if self.is_graph_input:
                self.graph._inputs[name] = self
                del self.graph._inputs[old_name]
            elif self.is_graph_output:
                self.graph._outputs[name] = self
                del self.graph._outputs[old_name]

    @property
    def shape(self):
        return self._shape

    @property
    def max_shape(self):
        return self._max_shape

    @property
    def raw_shape(self):
        assert isinstance(self._trt, trt.ITensor)
        return self._trt.shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @max_shape.setter
    def max_shape(self, max_shape):
        self._max_shape = max_shape

    @raw_shape.setter
    def raw_shape(self, raw_shape):
        assert isinstance(self._trt, trt.ITensor)
        self._trt.shape = raw_shape

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        self._value = value

    @property
    def dtype(self):
        return self._trt.dtype

    @property
    def broadcast_across_batch(self):
        return self._trt.broadcast_across_batch

    @property
    def dtype_size(self):
        return self.dtype.itemsize

    @property
    def dtype_str(self):
        return trt_dtype_to_str(self.dtype)

    @property
    def dtype_str_size(self):
        return [trt_dtype_to_str(self.dtype), self.dtype.itemsize]

    @property
    def is_graph_input(self) -> bool:
        return self.graph_input_index != -1

    @property
    def is_graph_output(self) -> bool:
        return self.graph_output_index != -1

    @property
    def is_graph_io(self) -> bool:
        return self.is_graph_input or self.is_graph_output


class Layer:

    def __init__(self, graph):
        self._graph = graph
        self._trt = None
        self._index = None
        self._inputs = []
        self._outputs = []
        self._is_shape_io = False
        self.attrs = {}

    @staticmethod
    def from_trt(graph, trt_layer, index):
        layer = Layer(graph)
        layer._trt = trt_layer
        layer._index = index
        for i in range(trt_layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                layer._inputs.append(graph.get_tensor(input.name))
                layer._inputs[i].consumers.append((layer, i))
            else:
                layer._inputs.append(None)
        for i in range(trt_layer.num_outputs):
            output = trt_layer.get_output(i)
            layer._outputs.append(graph.get_tensor(output.name))
            layer._outputs[i].producer = layer
            layer._outputs[i].output_index = i
        set_trt_network(trt_layer, graph.as_trt())
        return layer

    def as_trt(self) -> trt.ILayer:
        return self._trt

    @property
    def graph(self) -> "PipelineGraph":
        return self._graph

    @property
    def name(self) -> str:
        return self._trt.name

    @name.setter
    def name(self, name: str):
        old_name = self._trt.name
        if name != old_name:
            self._trt.name = name
            self.graph._layers[name] = self
            del self.graph._layers[old_name]

    @property
    def type(self) -> trt.LayerType:
        return self._trt.type

    @property
    def index(self) -> int:
        return self._index

    @property
    def inputs(self) -> List[Tensor]:
        return self._inputs

    @property
    def outputs(self) -> List[Tensor]:
        return self._outputs

    def get_input(self, index: int) -> Tensor:
        return self._inputs[index]

    def get_output(self, index: int) -> Tensor:
        return self._outputs[index]

    @property
    def num_inputs(self) -> int:
        return self._trt.num_inputs

    @property
    def num_outputs(self) -> int:
        return self._trt.num_outputs

    @property
    def is_shape_io(self) -> bool:
        return self._is_shape_io

    def to_subclass(self):
        to_subclass_layer(self._trt)

    def to_base_class(self):
        to_base_class_layer(self._trt)

    def assign_shapes(self, shapes, values):
        for output in self.outputs:
            output.shape = shapes[output.name]
            output.value = values.get(output.name)


@dataclass
class GraphRunner:
    session: Session
    inputs: Dict[str, torch.Tensor]
    outputs: Dict[str, torch.Tensor]
    stream: torch.Stream

    def run(self):
        cuda_stream = self.stream.cuda_stream
        assert self.session.run(self.inputs, self.outputs, cuda_stream)
        self.stream.synchronize()
        return self.outputs


class PipelineGraph:

    def __init__(self):
        self._trt = None
        self._inputs: Dict[str, Tensor] = {}
        self._outputs: Dict[str, Tensor] = {}
        self._layers: Dict[str, Layer] = {}
        self._tensors: Dict[str, Tensor] = {}
        self._io_buffer_mapping = {}
        self._unfilled_weights = {}
        self._auto_parallel_config = None
        self._plugin_config: PluginConfig = None

    @staticmethod
    def create_graph():
        graph = PipelineGraph()
        trt_builder = trt.Builder(logger.trt_logger)
        explicit_batch_flag = 0
        # Explicit batch flag will be deprecated in TRT 10
        if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
        ):
            explicit_batch_flag = 1 << int(
                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        if get_strongly_typed():
            network = trt_builder.create_network(
                explicit_batch_flag
                | (1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)))
        else:
            network = trt_builder.create_network(explicit_batch_flag)
        graph._trt = network
        return graph

    def _register_unfilled_weights(self, layer_name, weights, values):
        self._unfilled_weights[layer_name] = (weights, values)

    def _add_tensor(self, tensor, old_tensor, prefix):
        if prefix is not None:
            tensor.name = prefix + old_tensor.name
        else:
            tensor.name = old_tensor.name
        tensor.location = old_tensor.location
        if old_tensor.dynamic_range is not None:
            tensor.dynamic_range = old_tensor.dynamic_range
        if tensor.is_network_input:
            tensor.shape = old_tensor.shape
            for i in range(len(old_tensor.shape)):
                name = old_tensor.get_dimension_name(i)
                if name is not None:
                    tensor.set_dimension_name(i, name)
        return self._register_tensor(tensor)

    def _register_tensor(self, tensor):
        wrapped_tensor = Tensor.from_trt(self, tensor)
        assert tensor.name not in self._tensors
        self._tensors[tensor.name] = wrapped_tensor
        return wrapped_tensor

    def add_input(self, tensor, prefix=None):
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        input = self._trt.add_input(tensor_name, tensor.dtype, tensor.shape)
        new_tensor = self._add_tensor(input, tensor, prefix)
        new_tensor.graph_input_index = len(self._inputs)
        self._inputs[tensor_name] = new_tensor
        return new_tensor

    def register_input(self, tensor, index=None):
        if index is None:
            index = self.num_inputs - 1
        assert self._trt.get_input(index).name == tensor.name
        wrapped_input = self._register_tensor(tensor)
        wrapped_input.graph_input_index = index
        self._inputs[tensor.name] = wrapped_input
        return wrapped_input

    def add_output(self, tensor, prefix=None):
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        output = self.get_tensor(tensor_name)
        output.graph_output_index = len(self._outputs)
        trt_output = output.as_trt()
        self._trt.mark_output(trt_output)
        trt_output.dtype = tensor.dtype
        self._outputs[tensor_name] = output
        return output

    def add_output_shape(self, tensor, prefix=None):
        tensor_name = tensor.name
        if prefix is not None:
            tensor_name = prefix + tensor_name
        output = self.get_tensor(tensor_name)
        trt_output = output.as_trt()
        self._trt.mark_output_for_shapes(trt_output)
        trt_output.dtype = tensor.dtype
        self._outputs[tensor_name] = output
        return output

    def add_layer(
        self,
        layer,
        input_mapping=None,
        prefix=None,
        updated_attrs=None,
    ) -> Layer:

        def get_input(i):
            name = layer.get_input(i).name
            if prefix is not None:
                name = prefix + name
            if input_mapping is not None and name in input_mapping:
                name = input_mapping[name]
            return self.get_tensor(name).as_trt()

        network = self._trt
        layer_type = layer.type
        to_subclass_layer(layer)
        if layer_type == trt.LayerType.ACTIVATION:
            trt_input = get_input(0)
            new_layer = network.add_activation(trt_input, layer.type)
            new_layer.alpha = layer.alpha
            new_layer.beta = layer.beta
        elif layer_type == trt.LayerType.CONCATENATION:
            trt_inputs = [get_input(i) for i in range(layer.num_inputs)]
            new_layer = network.add_concatenation(trt_inputs)
            new_layer.axis = layer.axis
        elif layer_type == trt.LayerType.CONSTANT:
            new_layer = network.add_constant(layer.shape, layer.weights)
        elif layer_type == trt.LayerType.ELEMENTWISE:
            new_layer = network.add_elementwise(get_input(0), get_input(1),
                                                layer.op)
        elif layer_type == trt.LayerType.FILL:
            if layer.num_inputs >= 1 and layer.get_input(0) is not None:
                shape_input = get_input(0)
                shape = [1]
            else:
                shape_input = None
                shape = layer.shape
            new_layer = network.add_fill(shape, layer.operation, layer.to_type)
            if shape_input is not None:
                new_layer.set_input(0, shape_input)
            if layer.num_inputs >= 1 and layer.get_input(0) is not None:
                new_layer.set_input(0, get_input(0))
            if layer.num_inputs >= 2 and layer.get_input(1) is not None:
                new_layer.set_input(1, get_input(1))
            else:
                new_layer.alpha = layer.alpha
            if layer.num_inputs >= 3 and layer.get_input(2) is not None:
                new_layer.set_input(2, get_input(2))
            else:
                new_layer.beta = layer.beta
        elif layer_type == trt.LayerType.GATHER:
            trt_input = get_input(0)
            trt_indices = get_input(1)
            new_layer = network.add_gather_v2(trt_input, trt_indices,
                                              layer.mode)
            new_layer.axis = layer.axis
            new_layer.num_elementwise_dims = layer.num_elementwise_dims
            new_layer.mode = layer.mode
        elif layer_type == trt.LayerType.MATRIX_MULTIPLY:
            new_layer = network.add_matrix_multiply(get_input(0), layer.op0,
                                                    get_input(1), layer.op1)
        elif layer_type == trt.LayerType.REDUCE:
            new_layer = network.add_reduce(get_input(0), layer.op, layer.axes,
                                           layer.keep_dims)
        elif layer_type == trt.LayerType.SELECT:
            trt_condition = get_input(0)
            trt_then = get_input(1)
            trt_else = get_input(2)
            new_layer = network.add_select(trt_condition, trt_then, trt_else)
        elif layer_type == trt.LayerType.SHUFFLE:
            new_layer = network.add_shuffle(get_input(0))
            new_layer.first_transpose = layer.first_transpose
            new_layer.second_transpose = layer.second_transpose
            new_layer.zero_is_placeholder = layer.zero_is_placeholder
            if layer.num_inputs >= 2:
                trt_reshape_dims_tensor = get_input(1)
                new_layer.set_input(1, trt_reshape_dims_tensor)
            else:
                new_layer.reshape_dims = layer.reshape_dims
        elif layer_type == trt.LayerType.SLICE:
            if layer.num_inputs >= 2 and layer.get_input(1) is not None:
                trt_start = get_input(1)
                start = []
            else:
                trt_start = None
                start = layer.start
            if layer.num_inputs >= 3 and layer.get_input(2) is not None:
                trt_shape = get_input(2)
                shape = []
            else:
                trt_shape = None
                shape = layer.shape
            if layer.num_inputs >= 4 and layer.get_input(3) is not None:
                trt_stride = get_input(3)
                stride = []
            else:
                trt_stride = None
                stride = layer.stride
            new_layer = network.add_slice(get_input(0), start, shape, stride)
            new_layer.mode = layer.mode
            if trt_start is not None:
                new_layer.set_input(1, trt_start)
            if trt_shape is not None:
                new_layer.set_input(2, trt_shape)
            if trt_stride is not None:
                new_layer.set_input(3, trt_stride)
        elif layer_type == trt.LayerType.SOFTMAX:
            new_layer = network.add_softmax(get_input(0))
            new_layer.axes = layer.axes
        elif layer_type == trt.LayerType.UNARY:
            new_layer = network.add_unary(get_input(0), layer.op)
        elif layer_type == trt.LayerType.SHAPE:
            new_layer = network.add_shape(get_input(0))
        elif layer_type == trt.LayerType.ASSERTION:
            new_layer = network.add_assertion(get_input(0), layer.message)
        elif layer_type == trt.LayerType.CAST:
            new_layer = network.add_cast(get_input(0), layer.to_type)
        elif layer_type == trt.LayerType.NORMALIZATION:
            trt_input = get_input(0)
            trt_scale = get_input(1)
            trt_bias = get_input(2)
            new_layer = network.add_normalization(trt_input, trt_scale,
                                                  trt_bias, layer.axes)
            new_layer.epsilon = layer.epsilon
            new_layer.num_groups = layer.num_groups
            new_layer.compute_precision = layer.compute_precision
        elif layer_type == trt.LayerType.IDENTITY:
            new_layer = network.add_identity(get_input(0))
        elif layer_type == trt.LayerType.PLUGIN_V2:
            plugin = layer.plugin
            updated = False
            if (updated_attrs is not None
                    and updated_attrs.get("plugin") is not None):
                plugin = updated_attrs["plugin"]
                updated = True
                updated_attrs = None
            new_layer = network.add_plugin_v2(
                [get_input(i) for i in range(layer.num_inputs)],
                plugin,
            )
        else:
            raise NotImplementedError(
                "Unsupported layer type: {}".format(layer_type))

        if updated_attrs is not None:
            for attr_name, attr_value in updated_attrs.items():
                setattr(new_layer, attr_name, attr_value)

        to_base_class_layer(layer)
        to_base_class_layer(new_layer)
        layer_index = network.num_layers - 1
        layer_name = layer.name
        if prefix is not None:
            layer_name = prefix + layer_name
        new_layer.name = layer_name
        new_layer.metadata = new_layer.name
        if layer.precision_is_set:
            new_layer.precision = layer.precision
        for i in range(layer.num_outputs):
            if layer.output_type_is_set(i):
                new_layer.set_output_type(i, layer.get_output_type(i))
            output = new_layer.get_output(i)
            self._add_tensor(output, layer.get_output(i), prefix)
        wrapped_layer = Layer.from_trt(self, new_layer, layer_index)
        assert layer_name not in self._layers
        self._layers[layer_name] = wrapped_layer
        if layer_type == trt.LayerType.PLUGIN_V2:
            if not updated:
                plugin_info = get_plugin_info(get_trt_network(layer),
                                              layer.name)
                set_plugin_info(self.as_trt(), new_layer.name, plugin_info)
        return wrapped_layer

    def register_layer(self, layer, index=None):
        if index is None:
            index = self.num_layers - 1
        assert self._trt.get_layer(index).name == layer.name
        to_base_class_layer(layer)
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            self._register_tensor(output)
        wrapped_layer = Layer.from_trt(self, layer, index)
        assert layer.name not in self._layers
        self._layers[layer.name] = wrapped_layer
        to_subclass_layer(layer)
        return wrapped_layer

    def get_runner(
        self,
        shapes=None,
        values=None,
        profile=None,
        timing_cache=None,
        opt_level=None,
    ) -> GraphRunner:
        shapes = shapes or {}
        values = values or {}
        inputs = {}
        outputs = {}
        for input in self.inputs:
            if input is not None:
                value = values.get(input.name)
                if value is None:
                    value = input.value
                if value is not None:
                    if not isinstance(value, torch.Tensor):
                        value = torch.tensor(
                            value,
                            dtype=trt_dtype_to_torch(input.dtype),
                            device='cpu',
                        )
                    inputs[input.name] = value
                else:
                    shape = shapes.get(input.name)
                    if shape is None:
                        shape = input.shape
                    assert shape is not None
                    inputs[input.name] = torch.empty(
                        tuple(shape),
                        dtype=trt_dtype_to_torch(input.dtype),
                        device=torch.cuda.current_device(),
                    )
                    if torch.is_floating_point(inputs[input.name]):
                        inputs[input.name].normal_()
                    # inputs[input.name][:] = random.choice([2, 3, 5, 7])
        for output in self.outputs:
            if output.as_trt().is_shape_tensor:
                continue
            if output.name in self._io_buffer_mapping:
                input_name = self._io_buffer_mapping[output.name]
                if input_name in inputs:
                    outputs[output.name] = inputs[input_name]
                    continue
            value = values.get(output.name)
            if value is not None and isinstance(value, torch.Tensor):
                outputs[output.name] = value
            else:
                shape = shapes.get(output.name)
                if shape is None:
                    shape = output.shape
                assert shape is not None
                outputs[output.name] = torch.empty(
                    tuple(shape),
                    dtype=trt_dtype_to_torch(output.dtype),
                    device=torch.cuda.current_device(),
                )
        network = self.as_trt()
        config = network.builder.create_builder_config()
        if opt_level is not None:
            config.builder_optimization_level = opt_level
        config.flags = get_builder_flags()
        profile = profile or network.builder.create_optimization_profile()
        profile_index = config.add_optimization_profile(profile)
        if timing_cache is not None:
            config.set_timing_cache(timing_cache, ignore_mismatch=False)
        plan = network.builder.build_serialized_network(network, config)
        if plan is None:
            logger.error('Engine building failed, please check the error log.')
        session = Session.from_serialized_engine(plan)
        stream = torch.cuda.current_stream()
        cuda_stream = stream.cuda_stream
        context = session.context
        context.set_optimization_profile_async(profile_index, cuda_stream)
        runner = GraphRunner(session, inputs, outputs, stream)
        return runner

    def run(
        self,
        shapes=None,
        values=None,
        profile=None,
        timing_cache=None,
        opt_level=None,
    ):
        return self.get_runner(
            shapes,
            values,
            profile,
            timing_cache,
            opt_level,
        ).run()

    def duplicate_graph(self):
        graph = PipelineGraph.create_graph()
        network = self.as_trt()
        for i in range(network.num_inputs):
            input = network.get_input(i)
            graph.add_input(input)
        sorted_layer_ids = get_sorted_layer_ids(network)
        for i in sorted_layer_ids:
            layer = network.get_layer(i)
            graph.add_layer(layer)
        for i in range(network.num_outputs):
            output = network.get_output(i)
            if output.is_shape_tensor:
                graph.add_output_shape(output)
            else:
                graph.add_output(output)
        return graph

    @staticmethod
    def from_trt(trt_network):
        graph = PipelineGraph()
        graph._trt = trt_network

        # construct inputs and tensors
        for i in range(trt_network.num_inputs):
            trt_input = trt_network.get_input(i)
            tensor = Tensor.from_trt(graph, trt_input)
            tensor.graph_input_index = i
            graph._tensors[tensor.name] = tensor
            graph._inputs[tensor.name] = tensor
        for i in range(trt_network.num_layers):
            trt_layer = trt_network.get_layer(i)
            for i in range(trt_layer.num_outputs):
                trt_output = trt_layer.get_output(i)
                tensor = Tensor.from_trt(graph, trt_output)
                graph._tensors[tensor.name] = tensor

        # construct layers and outputs
        for i in range(trt_network.num_layers):
            layer = Layer.from_trt(graph, trt_network.get_layer(i), i)
            graph._layers[layer.name] = layer
        for i in range(trt_network.num_outputs):
            tensor_name = trt_network.get_output(i).name
            output_tensor = graph._tensors[tensor_name]
            output_tensor.graph_output_index = i
            graph._outputs[tensor_name] = output_tensor

        return graph

    @staticmethod
    def from_network(network: Network, builder_config):
        builder_flags = builder_config.trt_builder_config.flags
        with current_flags(builder_flags, network.strongly_typed):
            graph = PipelineGraph.from_trt(network.trt_network)
            graph.infer_shapes(network._generate_optimization_profiles()[-1])
            return graph

    def assign_shapes(self, shape_info=None, is_partial=False):
        if shape_info is None:
            for tensor in self.tensors:
                tensor.shape = tensor.raw_shape
            return
        for tensor in self.tensors:
            if tensor.name in shape_info.shapes:
                tensor.shape = shape_info.shapes[tensor.name]
            elif not is_partial:
                raise ValueError(f"Cannot find shape for tensor: {tensor.name}")
            if shape_info.max_shapes is not None:
                if tensor.name in shape_info.max_shapes:
                    tensor.max_shape = shape_info.max_shapes[tensor.name]
                elif not is_partial:
                    raise ValueError(
                        f"Cannot find max shape for tensor: {tensor.name}")
            if tensor.name in shape_info.values:
                tensor.value = shape_info.values[tensor.name]
        for layer in self.layers:
            if layer.name in shape_info.shape_layers:
                layer._is_shape_io = True

    def infer_shapes(self, profile=None):
        from .shape_info import get_shape_info

        shape_info = get_shape_info(self._trt, profile)
        self.assign_shapes(shape_info)

    def as_trt(self) -> trt.INetworkDefinition:
        return self._trt

    def get_input(self, name: str) -> Tensor:
        return self._inputs.get(name)

    def is_input(self, name: str) -> bool:
        return name in self._inputs

    @property
    def inputs(self) -> List[Tensor]:
        return [*self._inputs.values()]

    @property
    def num_inputs(self) -> int:
        return self._trt.num_inputs

    def get_output(self, name: str) -> Tensor:
        return self._outputs.get(name)

    def is_output(self, name: str) -> bool:
        return name in self._outputs

    @property
    def outputs(self) -> List[Tensor]:
        return [*self._outputs.values()]

    @property
    def num_outputs(self) -> int:
        return self._trt.num_outputs

    def get_tensor(self, name: str) -> Tensor:
        return self._tensors.get(name)

    @property
    def tensors(self) -> List[Tensor]:
        return [*self._tensors.values()]

    def get_layer(self, name: str) -> Layer:
        return self._layers.get(name)

    @property
    def layers(self) -> List[Layer]:
        return [*self._layers.values()]

    @property
    def sorted_layers(self) -> List[Layer]:
        sorted_layer_ids = get_sorted_layer_ids(self.as_trt())
        return [
            self.get_layer(self.as_trt().get_layer(layer_id).name)
            for layer_id in sorted_layer_ids
        ]

    @property
    def num_layers(self) -> int:
        return self._trt.num_layers

    def to_dot(self,
               path=None,
               per_device=False,
               per_block=False,
               ignore_shape_io=False,
               no_style=False,
               extra_attrs=None) -> Optional[str]:
        '''
        Get a graphviz representation of the graph.

        Parameters:
            path: the path to save the graphviz file, if not provided, will return the graphviz source code
        '''
        try:
            import graphviz
        except ImportError:
            logger.error(
                "Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()"
            )
            return

        extra_attrs = extra_attrs or []

        graph = graphviz.Digraph()
        input_block_graph = graphviz.Digraph(name='cluster_inputs')
        output_block_graph = graphviz.Digraph(name='cluster_outputs')
        device_graphs = {}
        block_graphs = {}
        block_graph_mapping = []
        tensor_names = set()
        layer_names = set()

        common_style = dict(fontname='Arial', )
        node_style = dict(
            **common_style,
            style='rounded,filled,bold',
        )
        tensor_style = dict(
            **node_style,
            shape='ellipse',
            fillcolor='white',
        )
        input_tensor_style = {**tensor_style, 'fillcolor': 'green'}
        output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'}
        layer_style = dict(
            **node_style,
            shape='box',
            fillcolor='white',
        )
        shape_layer_style = {**layer_style, 'fillcolor': 'grey'}
        helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'}
        graph_style = dict(
            **common_style,
            style='rounded',
            penwidth='5',
            fontsize='28',
        )
        device_graph_style = dict(
            **graph_style,
            color='cornflowerblue',
        )
        block_graph_style = dict(
            **graph_style,
            color='darkcyan',
        )
        input_block_style = dict(
            **graph_style,
            color='green',
        )
        output_block_style = dict(
            **graph_style,
            color='lightgreen',
        )
        if no_style:
            device_graph_style = {}
            block_graph_style = {}
            input_block_style = {}
            output_block_style = {}
        input_block_graph.attr(label='inputs', **input_block_style)
        output_block_graph.attr(label='outputs', **output_block_style)

        def get_tensor_labels(tensor):
            labels = []
            if tensor.value is not None:
                labels.append(f"value={tensor.value}")
            else:
                labels.append(f"dtype={tensor.dtype.name}{tensor.shape}")
            for attr_name in extra_attrs:
                if attr_name in tensor.attrs:
                    labels.append(f"{attr_name}={tensor.attrs[attr_name]}")
            return labels

        def get_device_graph(name):
            if per_device and name.startswith('device'):
                device_name = name.split('_')[0]
                if device_name not in device_graphs:
                    device_graph = graphviz.Digraph(name='cluster_' +
                                                    device_name)
                    device_graph.attr(label=device_name, **device_graph_style)
                    device_graphs[device_name] = device_graph
                return device_graphs[device_name]
            return None

        def get_block_graph(layer, current_graph):
            if per_block and 'block_id' in layer.attrs:
                block_label = f"block{layer.attrs['block_id']}"
                if current_graph.name is not None:
                    graph_label = current_graph.name[len('cluster_'):]
                else:
                    graph_label = ''
                block_name = f"{graph_label}{block_label}"
                if block_name not in block_graphs:
                    block_graph = graphviz.Digraph(name='cluster_' + block_name)
                    block_graph.attr(label=block_label, **block_graph_style)
                    block_graphs[block_name] = block_graph
                    block_graph_mapping.append((current_graph, block_graph))
                return block_graphs[block_name]
            return current_graph

        for name, tensor in self._tensors.items():
            style = tensor_style
            if tensor.is_graph_input:
                style = input_tensor_style
                current_graph = input_block_graph
            elif tensor.is_graph_output:
                style = output_tensor_style
                current_graph = output_block_graph
            elif tensor.producer.num_outputs == 1:
                continue
            else:
                current_graph = get_device_graph(name) or graph
                current_graph = get_block_graph(tensor.producer, current_graph)
            if no_style:
                style = {}
            labels = [name, *get_tensor_labels(tensor)]
            content = "\n".join(labels)
            current_graph.node(name, content, **style)
            tensor_names.add(name)

        for layer in self.sorted_layers:
            name = layer.name

            style = layer_style
            if layer.is_shape_io:
                if ignore_shape_io:
                    continue
                style = shape_layer_style
            elif layer.attrs.get("role", None) == "helper":
                style = helper_layer_style
            fillcolor = None
            plugin_type = None
            if layer.type == trt.LayerType.PLUGIN_V2:
                fillcolor = 'yellow'
                layer.to_subclass()
                plugin_type = layer.as_trt().plugin.plugin_type
                layer.to_base_class()
            if layer.type == trt.LayerType.MATRIX_MULTIPLY or plugin_type == 'Gemm':
                fillcolor = 'orange'
            if fillcolor is not None:
                style = {**style, 'fillcolor': fillcolor}
            if no_style:
                style = {}

            layer_attrs = {}
            layer_type = layer.type
            layer.to_subclass()
            if layer_type == trt.LayerType.CONSTANT:
                if not layer.is_shape_io:
                    if trt.volume(layer.get_output(0).shape) <= 8:
                        weights = layer.as_trt().weights
                        if isinstance(weights, trt.Weights):
                            weights = weights.numpy()
                        value = np.array2string(
                            weights,
                            formatter={'float_kind': lambda x: f"{x:.2e}"})
                        layer_attrs['value'] = value
            elif layer_type == trt.LayerType.SHUFFLE:
                for attr_name in ['first_transpose', 'second_transpose']:
                    attr_value = getattr(layer.as_trt(), attr_name)
                    if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7):
                        tensor = layer.get_input(
                            0
                        ) if attr_name == 'first_transpose' else layer.get_output(
                            0)
                        layer_attrs[attr_name] = tuple(
                            attr_value)[:len(tensor.shape)]
                if layer.num_inputs < 2:
                    attr_value = layer.as_trt().reshape_dims
                    layer_attrs['reshape_dims'] = attr_value
            elif layer_type == trt.LayerType.SLICE:
                if layer.num_inputs < 2 or layer.get_input(1) is None:
                    layer_attrs['start'] = layer.as_trt().start
                if layer.num_inputs < 4 or layer.get_input(3) is None:
                    attr_value = layer.as_trt().stride
                    if attr_value != tuple(
                            [1] * len(layer.get_output(0).shape)):
                        layer_attrs['stride'] = attr_value
            layer.to_base_class()

            if layer.is_shape_io:
                labels = [layer.type.name]
            else:
                labels = [name, layer.type.name]
            for key, value in layer_attrs.items():
                labels.append(f"{key}={value}")
            for attr_name in extra_attrs:
                if attr_name in layer.attrs:
                    labels.append(f"{attr_name}={layer.attrs[attr_name]}")
            if layer.num_outputs == 1:
                output = layer.get_output(0)
                if output.name != f'{layer.name}_output_0':
                    labels.append(f"output={output.name}")
                labels.extend(get_tensor_labels(output))
            content = "\n".join(labels)

            current_graph = get_device_graph(name) or graph
            current_graph = get_block_graph(layer, current_graph)
            current_graph.node(name, content, **style)
            layer_names.add(name)

            for index, input in enumerate(layer.inputs):
                if input is not None:
                    if input.is_graph_input or input.producer.num_outputs > 1:
                        if input.name in tensor_names:
                            graph.edge(input.name, name, str(index))
                    else:
                        if input.producer.name in layer_names:
                            graph.edge(input.producer.name, name, str(index))
            if layer.num_outputs > 1 or (layer.num_outputs == 1 and
                                         layer.get_output(0).is_graph_output):
                for index, output in enumerate(layer.outputs):
                    graph.edge(name, output.name, str(index))

        graph.subgraph(input_block_graph)
        graph.subgraph(output_block_graph)
        for parent_graph, block_graph in block_graph_mapping:
            parent_graph.subgraph(block_graph)
        for device_graph in device_graphs.values():
            graph.subgraph(device_graph)

        if not path:
            return graph.source
        graph.save(path)

    @staticmethod
    def trt_to_dot(trt_network, path=None):
        graph = PipelineGraph.from_trt(trt_network)
        graph.assign_shapes()
        dot = graph.to_dot(no_style=True)
        if path is not None:
            with open(path, "w") as f:
                f.write(dot)
        else:
            return dot
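As with the previous file, the following is an editor-added usage sketch (not part of the uploaded file) showing the round trip that PipelineGraph supports: wrapping an existing TensorRT network, assigning shapes, duplicating it, and emitting graphviz output. `trt_network` is a placeholder for an already-built trt.INetworkDefinition, and `to_dot` only works when the optional graphviz package is installed.

    # Illustrative only: placeholder network, optional graphviz dependency.
    graph = PipelineGraph.from_trt(trt_network)
    graph.assign_shapes()                             # no ShapeInfo given: fall back to raw TRT shapes
    copy = graph.duplicate_graph()                    # re-adds every layer into a fresh network
    dot_source = graph.to_dot(ignore_shape_io=True)   # returns graphviz source when path is None
    print(graph.num_layers, copy.num_layers)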
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/runtime_profiling.py
ADDED
@@ -0,0 +1,150 @@
import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info

from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin


class NvtxProfiler(object):

    def __init__(self, nvtx_name, enable=True):
        self.nvtx_name = nvtx_name
        self.enable = enable

    def __enter__(self):
        if self.enable:
            torch.cuda.nvtx.range_push(self.nvtx_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            torch.cuda.nvtx.range_pop()


class LayerProfiler(trt.IProfiler):

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.layer_count = 0
        self.time = 0

    def report_layer_time(self, layer_name, ms):
        logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
        self.time += ms
        self.layer_count += 1


class RuntimeProfiler(object):

    def __init__(self):
        self.timing_cache = None

    def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
        is_plugin = layer.type == trt.LayerType.PLUGIN_V2
        if is_plugin and len(layer_attrs) > 0:
            plugin_info = get_plugin_info(
                get_trt_network(layer),
                layer.name,
            )
            new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
            layer_attrs = {"plugin": new_plugin}
        graph, output_mapping = get_per_layer_graph(layer, shapes, values,
                                                    layer_attrs)
        graph._io_buffer_mapping = io_buffer_mapping
        network = graph.as_trt()
        if network.num_outputs > 0 and np.all([
                network.get_output(i).is_shape_tensor
                for i in range(network.num_outputs)
        ]):
            return 0.0
        for proxy_output, output in output_mapping.items():
            shapes[proxy_output] = shapes[output]
        if not self.timing_cache:
            self.timing_cache = network.builder.create_builder_config(
            ).create_timing_cache(b"")
        runner = graph.get_runner(
            shapes,
            values,
            timing_cache=self.timing_cache,
        )
        context = runner.session.context
        context.profiler = LayerProfiler()
        runner.run()
        profiler_time_first_run = context.profiler.time
        runner.run()
        return (context.profiler.time - profiler_time_first_run) * 1000.0

    def runtime_profile(self, layer, layer_attrs, input_values, strategy,
                        device_mesh):
        logger.debug(f"start to profile layer {layer.name}")
        shapes = {}
        values = {}
        dtypes = {}
        trt_layer = layer.as_trt()

        sharding_sequences = ()
        for i in range(layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                shapes[input.name] = strategy.sharding_specs[
                    f'input{i}'].get_sharded_shape_per_device()
                dtypes[input.name] = input.dtype
                sharding_sequences += (str(
                    strategy.sharding_specs[f"input{i}"].sharding_sequence), )
                if i in input_values:
                    values[input.name] = input_values[i]
                else:
                    value = layer.get_input(i).value
                    if value is not None:
                        values[input.name] = value
            else:
                sharding_sequences += (None, )

        for i in range(layer.num_outputs):
            output = trt_layer.get_output(i)
            if f'output{i}' in strategy.communication_actions:
                shapes[output.name] = strategy.communication_actions[
                    f'output{i}'].sharding_spec.get_sharded_shape_per_device()
            else:
                shapes[output.name] = strategy.sharding_specs[
                    f'output{i}'].get_sharded_shape_per_device()
            dtypes[output.name] = output.dtype
            sharding_sequences += (str(
                strategy.sharding_specs[f"output{i}"].sharding_sequence), )
        data_key = get_cache_key(
            trt_layer,
            shapes,
            values,
            dtypes=dtypes,
            updated_attrs=layer_attrs,
        )
        data_key += (sharding_sequences, )
        elapsed_time = device_mesh.prof_database.query(
            device_mesh.cluster_key,
            data_key,
        )
        if elapsed_time:
            logger.debug(
                f'runtime profiling cache hit {data_key}: {elapsed_time} us')
            return elapsed_time
        with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
            elapsed_time = self._profile(
                layer.as_trt(),
                layer_attrs,
                shapes,
                values,
                layer.graph._io_buffer_mapping,
            )
        logger.debug(
            f'runtime profiling cache miss {data_key}: {elapsed_time} us')

        device_mesh.prof_database.update(
            device_mesh.cluster_key,
            data_key,
            (elapsed_time, strategy.alpha_beta_cost),
        )

        return elapsed_time
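A small sketch of how the `NvtxProfiler` context manager above can be used to mark a region for Nsight Systems timelines. The `do_work` callable and the range name are placeholders, not part of the uploaded file.

```python
# Illustrative only: push/pop an NVTX range around some GPU work using the
# NvtxProfiler defined above. Ranges are recorded on the CPU timeline, so we
# synchronize afterwards to flush any queued kernels before measuring.
import torch
from tensorrt_llm.auto_parallel.runtime_profiling import NvtxProfiler


def profile_region(do_work, enable=True):
    with NvtxProfiler("example_region", enable=enable):
        do_work()
    torch.cuda.synchronize()
```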
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/shape_info.py
ADDED
@@ -0,0 +1,362 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set

import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str,
                                 trt_dtype_to_torch)
from tensorrt_llm.logger import logger

from .pipeline_graph import PipelineGraph
from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids,
                    set_trt_network, to_base_class_layer, to_subclass_layer,
                    to_trt_weights)


class ShapeType(Enum):
    MIN = 0
    OPT = 1
    MAX = 2


_trt_to_type_dict = {
    trt.int64: int,
    trt.bool: bool,
}


def get_shape_layers(trt_network):
    shape_layers = set()
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        if (layer.num_inputs > 0 and np.all([
                layer.get_input(j).is_shape_tensor
                for j in range(layer.num_inputs)
                if layer.get_input(j) is not None
        ])) or (layer.num_outputs > 0 and np.all([
                layer.get_output(j).is_shape_tensor
                for j in range(layer.num_outputs)
        ])):
            shape_layers.add(layer.name)
    return shape_layers


def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids):
    layers = set()
    shape_tensors = set()
    for layer_id in reversed(sorted_layer_ids):
        layer = trt_network.get_layer(layer_id)
        in_shape_network = False
        if layer.name in shape_layers:
            in_shape_network = True
        else:
            for j in range(layer.num_outputs):
                output = layer.get_output(j)
                if output.name in shape_tensors:
                    in_shape_network = True
                    break
        if in_shape_network:
            layers.add(layer.name)
            for j in range(layer.num_inputs):
                input = layer.get_input(j)
                if input is not None:
                    shape_tensors.add(input.name)
    return layers


def get_shape_network(trt_network,
                      shapes,
                      values,
                      sorted_layer_ids,
                      profile=None,
                      shape_type: ShapeType = ShapeType.OPT):
    shape_layers = get_shape_layers(trt_network)
    layers_in_shape_network = get_layers_in_shape_network(
        trt_network, shape_layers, sorted_layer_ids)
    shape_graph = PipelineGraph.create_graph()
    shape_network = shape_graph.as_trt()
    shape_builder = shape_network.builder
    shape_profile = shape_builder.create_optimization_profile()
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        shapes[input.name] = input.shape
        new_input = shape_graph.add_input(input)
        if profile is not None:
            if -1 in input.shape:
                shape = profile.get_shape(input.name)
                shape = shape[shape_type.value]
                shapes[input.name] = shape
                new_input.raw_shape = shape
            if input.is_shape_tensor:
                shape_values = profile.get_shape_input(input.name)
                value = shape_values[shape_type.value]
                values[input.name] = value
                shape_profile.set_shape_input(input.name, value, value, value)
    output_mapping = {}
    for layer_id in sorted_layer_ids:
        layer = trt_network.get_layer(layer_id)
        if layer.name in shape_layers:
            new_layer = shape_graph.add_layer(layer)
            for i in range(layer.num_outputs):
                output = layer.get_output(i)
                if output.dtype == trt.DataType.BOOL:
                    proxy_layer = shape_network.add_cast(
                        new_layer.as_trt().get_output(i),
                        trt.DataType.INT32,
                    )
                    proxy_output = proxy_layer.get_output(0)
                    shape_graph.register_layer(proxy_layer)
                    shape_graph.add_output_shape(proxy_output)
                    output_mapping[proxy_output.name] = (output.name,
                                                         output.dtype)
                else:
                    shape_graph.add_output_shape(output)
        elif layer.name in layers_in_shape_network:
            if layer.type == trt.LayerType.CONSTANT:
                shape_graph.add_input(layer.get_output(0))
            else:
                shape_graph.add_layer(layer)
    return shape_network, shape_profile, shape_layers, output_mapping


def get_per_layer_graph(
    layer,
    shapes,
    values,
    updated_attrs=None,
    is_shape_io: bool = None,
):
    graph = PipelineGraph.create_graph()
    network = graph.as_trt()
    is_shape_layer = layer.num_inputs != 0
    for i in range(layer.num_inputs):
        input = layer.get_input(i)
        if input is not None:
            shape = shapes[input.name]
            if (values.get(input.name) is not None
                    and not isinstance(values[input.name], torch.Tensor)):
                value = values[input.name]
                weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype))
                weights = to_trt_weights(weights)
                input_layer = network.add_constant(shape, weights)
                new_input = input_layer.get_output(0)
                new_input.name = input.name
                graph.register_layer(input_layer)
            elif graph.get_input(input.name) is None:
                new_input = graph.add_input(input)
                new_input.raw_shape = shapes[input.name]
                is_shape_layer = False
    new_layer = graph.add_layer(
        layer,
        updated_attrs=updated_attrs,
    )
    output_mapping = {}
    if layer.type == trt.LayerType.SHAPE:
        is_shape_layer = True
    if layer.num_inputs == 0:
        is_shape_layer = False
    if is_shape_io is not None:
        is_shape_layer = is_shape_io
    for i in range(layer.num_outputs):
        output = layer.get_output(i)
        value = values.get(output.name)
        if value is not None and isinstance(value, torch.Tensor):
            is_output_shape = False
        elif is_shape_layer:
            is_output_shape = True
        else:
            is_output_shape = False
        if is_output_shape:
            if output.dtype == trt.DataType.BOOL:
                proxy_layer = network.add_cast(
                    new_layer.as_trt().get_output(i),
                    trt.DataType.INT32,
                )
                proxy_output = proxy_layer.get_output(0)
                graph.register_layer(proxy_layer)
                output_mapping[proxy_output.name] = (output.name, output.dtype)
                output = proxy_output
            graph.add_output_shape(output)
        else:
            graph.add_output(output)
    return graph, output_mapping


@_is_building
def infer_shapes(network, shapes, values, profile=None):
    if network.num_outputs == 0:
        return
    builder = network.builder
    config = builder.create_builder_config()
    config.builder_optimization_level = 0
    config.flags = get_builder_flags()
    profile = profile or builder.create_optimization_profile()
    config.add_optimization_profile(profile)
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError(
            'Engine building failed when inferring shapes, please check the error log.'
        )
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(plan)
    context = engine.create_execution_context()
    for i in range(network.num_inputs):
        input = network.get_input(i)
        if input.is_shape_tensor:
            value = values[input.name]
            context.set_shape_input(engine[input.name], value)
    for i in range(network.num_outputs):
        output = network.get_output(i)
        shape = context.get_tensor_shape(output.name)
        shapes[output.name] = shape
        if output.is_shape_tensor:
            if shape == [0]:
                values[output.name] = []
            else:
                if shape == []:
                    shape = [1]
                value = torch.empty(
                    list(shape),
                    dtype=trt_dtype_to_torch(output.dtype),
                    device="cpu",
                )
                values[output.name] = value
                context.set_tensor_address(output.name, value.data_ptr())
    context.infer_shapes()
    assert context.all_binding_shapes_specified
    for i in range(network.num_outputs):
        output = network.get_output(i)
        if isinstance(values.get(output.name), torch.Tensor):
            values[output.name] = values[output.name].tolist()


@dataclass
class ShapeInfo:
    shapes: Dict[str, trt.Dims]
    values: Dict[str, List[int]]
    shape_layers: Set[str]
    max_shapes: Dict[str, trt.Dims] = None


def set_constant_value(layer, values):
    to_subclass_layer(layer)
    output_name = layer.get_output(0).name
    weights = layer.weights
    if isinstance(weights, trt.Weights):
        weights = weights.numpy()
    values[output_name] = list(weights)
    to_base_class_layer(layer)


def infer_per_layer_shapes(
    layer: trt.ILayer,
    shapes,
    values,
    cache=None,
    is_shape_io=False,
):
    if layer.type == trt.LayerType.CONSTANT:
        to_subclass_layer(layer)
        output_name = layer.get_output(0).name
        shape = layer.shape
        shapes[output_name] = shape
        if is_shape_io:
            set_constant_value(layer, values)
        to_base_class_layer(layer)
        return
    elif layer.type == trt.LayerType.SHAPE:
        input_name = layer.get_input(0).name
        output_name = layer.get_output(0).name
        shape = [*shapes[input_name]]
        shapes[output_name] = trt.Dims([len(shape)])
        values[output_name] = shape
        return
    if cache is not None:
        cache_key = get_cache_key(layer, shapes, values)
        if cache_key in cache:
            output_shapes, output_values = cache[cache_key]
            for i in range(layer.num_outputs):
                output = layer.get_output(i)
                shapes[output.name] = output_shapes[i]
                if output_values[i] is not None:
                    values[output.name] = output_values[i]
            return
    graph, output_mapping = get_per_layer_graph(layer, shapes, values)
    dtypes = [
        trt_dtype_to_str(layer.get_input(i).dtype)
        for i in range(layer.num_inputs)
    ]
    layer_info = (f"type={cache_key[0]}, "
                  f"attrs={dict(cache_key[1])}, "
                  f"dtypes={dtypes}, "
                  f"shapes={list(cache_key[2])}, "
                  f"values={list(cache_key[3])}")
    logger.debug(f"infer shapes for layer {layer.name} ({layer_info})")
    try:
        infer_shapes(graph.as_trt(), shapes, values)
    except RuntimeError as e:
        raise RuntimeError(
            f"infer shapes failed for layer {layer.name} ({layer_info})") from e
    for proxy_output, (output, dtype) in output_mapping.items():
        shapes[output] = shapes[proxy_output]
        del shapes[proxy_output]
        if proxy_output in values:
            values[output] = [
                *map(_trt_to_type_dict[dtype], values[proxy_output])
            ]
            del values[proxy_output]
    if cache is not None:
        logger.debug(
            f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}"
        )
        output_shapes = []
        output_values = []
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            output_shapes.append(shapes[output.name])
            output_values.append(values.get(output.name))
        cache[cache_key] = (output_shapes, output_values)


def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT):
    shapes = {}
    values = {}
    sorted_layer_ids = get_sorted_layer_ids(trt_network)
    infer_shape_layers = False

    shape_network, shape_profile, shape_layers, output_mapping = get_shape_network(
        trt_network,
        shapes,
        values,
        sorted_layer_ids,
        profile=profile,
        shape_type=shape_type)
    try:
        infer_shapes(shape_network, shapes, values, shape_profile)
        for proxy_output, (output, dtype) in output_mapping.items():
            shapes[output] = shapes[proxy_output]
            values[output] = [
                *map(_trt_to_type_dict[dtype], values[proxy_output])
            ]
            del shapes[proxy_output]
            del values[proxy_output]
    except RuntimeError:
        infer_shape_layers = True

    cache = {}
    for layer_id in sorted_layer_ids:
        layer = trt_network.get_layer(layer_id)
        is_shape_io = layer.name in shape_layers
        if is_shape_io and not infer_shape_layers:
            continue
        set_trt_network(layer, trt_network)
        infer_per_layer_shapes(layer,
                               shapes,
                               values,
                               cache,
                               is_shape_io=is_shape_io)
    return ShapeInfo(shapes, values, shape_layers)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/simplifier.py
ADDED
@@ -0,0 +1,837 @@
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple

import numpy as np

from tensorrt_llm.network import Network

from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger


class StageType(Enum):
    START = 0
    BLOCK = 1
    END = 2


class BuildingBlock:

    def __init__(self, graph, layer_range) -> None:
        self.graph = graph
        self.layer_range = layer_range
        self.network = graph.as_trt()
        self.owned_inputs = {}
        self.is_edges_collected = False
        self.intra_edges = []
        self.src_inter_edges = []
        self.dst_inter_edges = []
        self.relative_src_inter_edges = []
        self.relative_dst_inter_edges = []
        self.relative_inter_edges = set()
        self.edge_hash = None
        self.outputs = None
        self.type_id = -1
        self.block_id = -1
        self.p2p_type = None
        self.is_superset = False
        self.is_subset = False
        self.sorted_layer_ids = []

    def collect_edges(self):
        if self.is_edges_collected:
            return
        for layer_index in self.layer_range:
            trt_layer = self.network.get_layer(layer_index)
            layer = self.graph.get_layer(trt_layer.name)
            layer_offset = layer.index - self.layer_range.start
            for input_index, input in enumerate(layer.inputs):
                if input is not None:
                    if input.is_graph_input:
                        is_owned = input.graph_input_index in self.owned_inputs
                        if not is_owned and np.all([
                                layer.index in self.layer_range or np.all([
                                    output.as_trt().is_shape_tensor
                                    for output in layer.outputs
                                ]) for layer, _ in input.consumers
                        ]):
                            self.owned_inputs[input.graph_input_index] = len(
                                self.owned_inputs)
                            is_owned = True
                        if is_owned:
                            self.intra_edges.append(
                                (-1, self.owned_inputs[input.graph_input_index],
                                 layer_offset, input_index))
                        else:
                            self.dst_inter_edges.append(
                                (-1, input.graph_input_index, layer_offset,
                                 input_index))
                    else:
                        src_layer_index = input.producer.index
                        if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
                            self.dst_inter_edges.append(
                                (src_layer_index, input.output_index,
                                 layer_offset, input_index))
                        else:
                            src_layer_offset = src_layer_index - self.layer_range.start
                            self.intra_edges.append(
                                (src_layer_offset, input.output_index,
                                 layer_offset, input_index))
            for output_index, output in enumerate(layer.outputs):
                for dst_layer, dst_input_index in output.consumers:
                    dst_layer_index = dst_layer.index
                    if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
                        self.src_inter_edges.append(
                            (layer_offset, output_index, dst_layer_index,
                             dst_input_index))
        self.edge_hash = tuple(self.intra_edges)
        self.outputs = sorted(
            set((edge[0], edge[1]) for edge in self.src_inter_edges))
        self.is_edges_collected = True

    def collect_relative_inter_edges(self, layer_to_block):
        self.collect_edges()
        for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
            if src_layer_index in layer_to_block:
                src_block = layer_to_block[src_layer_index]
                src_layer_offset = src_layer_index - src_block.layer_range.start
                dst = (self.type_id, dst_layer_index, dst_input_index)
                self.relative_dst_inter_edges.append(
                    (src_block.type_id, src_layer_offset, src_output_index,
                     *dst))
            else:
                self.relative_dst_inter_edges.append(
                    (-1, src_layer_index, src_output_index, self.type_id,
                     dst_layer_index, dst_input_index))
        self.relative_inter_edges = set(self.relative_dst_inter_edges +
                                        self.outputs)

    def get_input_names(self):
        self.collect_edges()
        input_tensor_names = []
        for edge in self.dst_inter_edges:
            layer_index = edge[0]
            output_index = edge[1]
            if layer_index == -1:
                tensor_name = self.network.get_input(output_index).name
            else:
                tensor_name = self.network.get_layer(layer_index).get_output(
                    output_index).name
            input_tensor_names.append(tensor_name)
        return input_tensor_names

    def get_input_mapping(self, last_blocks):
        input_mapping = {}
        for tensor_name, relative_edge in zip(self.get_input_names(),
                                              self.relative_dst_inter_edges):
            type_id = relative_edge[0]
            output_index = relative_edge[2]
            if type_id >= 0:
                last_block = last_blocks[type_id]
                layer_offset = relative_edge[1]
                mapped_layer_index = last_block.layer_range.start + layer_offset
                mapped_tensor_name = self.network.get_layer(
                    mapped_layer_index).get_output(output_index).name
                input_mapping[tensor_name] = mapped_tensor_name
            else:
                input_mapping[tensor_name] = tensor_name
        return input_mapping


@dataclass
class GraphMapping:
    layer_mapping: Dict[int, int] = None
    block_mapping: Dict[int, int] = None
    p2p_types: Dict[int, P2PType] = None
    p2p_tensors: Dict[int, List[str]] = None
    block_to_stage: Dict[int, int] = None
    same_spec_layer_mapping: Dict[str, str] = None


@dataclass
class GraphConfig:
    num_micro_batches: int = 1
    num_blocks: int = 1
    num_stages: int = 1
    has_cross_device: bool = False
    has_cross_host: bool = False
    graph_mapping: GraphMapping = None
    phy_mesh: PhysicalDeviceMesh = None
    stage_phy_meshes: List[PhysicalDeviceMesh] = None


class Simplifier:

    def __init__(self, network: Network, config: AutoParallelConfig):
        self.config = config
        self.sharded_io_allowlist = config.sharded_io_allowlist
        self.same_buffer_io = config.same_buffer_io
        self.same_spec_io = config.same_spec_io.copy()
        for key, value in self.same_buffer_io.items():
            if key not in self.same_spec_io:
                self.same_spec_io[key] = value

        self.llm_network = network
        self.network = network.trt_network
        self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
        self.graph = self.get_graph()
        self.init_layer_hash()

        module_tree = self.get_module_tree()
        building_blocks = self.collect_building_blocks(module_tree)
        blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
        self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
            blocks_by_module_hash)
        self.layer_to_block = self.get_layer_to_block()
        self.blocks = self.get_all_blocks()
        self.backbone_blocks = self.get_backbone_blocks()
        self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
        self.graph_for_shape = self.create_simplified_graph_for_shape()
        self.shape_info = None
        self.num_micro_batches = None

    def infer_shapes(self, num_micro_batches):
        if self.num_micro_batches == num_micro_batches:
            return
        with silent_trt_logger():
            self.shape_info = self.get_full_shape_info(num_micro_batches)
            self.graph.assign_shapes(self.shape_info)
            self.num_micro_batches = num_micro_batches

    def list_all_num_micro_batches(self):
        opt_batch_size = self.get_opt_batch_size()
        candidates = []
        for num_micro_batches in range(1, self.get_opt_batch_size() + 1):
            if opt_batch_size % num_micro_batches == 0:
                candidates.append(num_micro_batches)
        return candidates

    def get_graph(self):
        graph = PipelineGraph.from_trt(self.network)
        graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
        graph._io_buffer_mapping
        for input in graph.inputs:
            input_name = input.name
            for pattern, repl in self.same_buffer_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = graph.get_output(output_name)
                    if output is not None:
                        graph._io_buffer_mapping[output_name] = input_name
        return graph

    def get_opt_batch_size(self):
        input_tensors = self.llm_network._inputs
        num_profiles = len(list(input_tensors.values())[0].profiles)
        opt_batch_sizes = []
        for i in range(num_profiles):
            for input_tensor in input_tensors.values():
                shape_profile = input_tensor.profiles[i]
                opt_shape = shape_profile.opt
                for j in range(len(input_tensor.shape)):
                    name = input_tensor.trt_tensor.get_dimension_name(j)
                    if name == 'batch_size':
                        opt_batch_sizes.append(opt_shape[j])
        return min(opt_batch_sizes)

    def get_module_hash(self, layer_range):
        module_hash = ()
        for i in layer_range:
            assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
            layer_name = self.network.get_layer(i).name
            layer = self.graph.get_layer(layer_name)
            module_hash += (layer.attrs["hash"], )
        return module_hash

    def get_network_hash(self) -> str:
        return str(self.get_module_hash(range(self.network.num_layers)))

    def collect_building_blocks(self, module_tree):
        building_blocks = {}
        queue = []
        for tree in module_tree["children"].values():
            queue.append(tree)
        while len(queue) > 0:
            while len(queue) > 0:
                tree = queue.pop(0)
                module_name = tree["name"]
                if module_name is None:
                    for child in tree["children"].values():
                        queue.append(child)
                    continue
                layer_range = self.module_to_layer_range_map[module_name]
                module_hash = self.get_module_hash(layer_range)
                if module_hash in building_blocks:
                    building_blocks[module_hash].append(tree)
                else:
                    building_blocks[module_hash] = [tree]
            for module_hash in [*building_blocks.keys()]:
                if len(building_blocks[module_hash]) == 1:
                    tree = building_blocks[module_hash][0]
                    for child in tree["children"].values():
                        queue.append(child)
                    del building_blocks[module_hash]
        blocks_by_module_hash = {
            module_hash: [
                BuildingBlock(self.graph,
                              self.module_to_layer_range_map[tree["name"]])
                for tree in trees
            ]
            for module_hash, trees in building_blocks.items()
        }
        building_blocks = []
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        if len(building_blocks) >= 2:
            for block, next_block in zip(building_blocks[:-1],
                                         building_blocks[1:]):
                block.layer_range = range(block.layer_range.start,
                                          next_block.layer_range.start)
        return building_blocks

    def get_all_blocks(self):
        building_blocks = []
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        all_blocks = []
        current_layer_index = 0
        block_id = 0
        for block in building_blocks:
            assert current_layer_index <= block.layer_range.start
            if current_layer_index < block.layer_range.start:
                new_block = BuildingBlock(
                    self.graph,
                    range(current_layer_index, block.layer_range.start))
                new_block.block_id = block_id
                block_id += 1
                all_blocks.append(new_block)
            block.block_id = block_id
            block_id += 1
            all_blocks.append(block)
            current_layer_index = block.layer_range.stop
        if current_layer_index < self.graph.num_layers:
            new_block = BuildingBlock(
                self.graph, range(current_layer_index, self.graph.num_layers))
            new_block.block_id = block_id
            all_blocks.append(new_block)
        sorted_layer_ids = get_sorted_layer_ids(self.network)
        for block in all_blocks:
            block.collect_relative_inter_edges(self.layer_to_block)
            for layer_id in sorted_layer_ids:
                if layer_id in block.layer_range:
                    block.sorted_layer_ids.append(layer_id)
        return all_blocks

    def get_backbone_blocks(self):
        sorted_blocks = sorted(
            self.blocks_by_edge_hash.values(),
            key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
        )
        if len(sorted_blocks) == 0:
            return []
        else:
            return sorted_blocks[-1]

    def get_blocks_by_module_hash(self, blocks):
        blocks_by_module_hash = {}
        for block in blocks:
            module_hash = self.get_module_hash(block.layer_range)
            if module_hash not in blocks_by_module_hash:
                blocks_by_module_hash[module_hash] = []
            blocks_by_module_hash[module_hash].append(block)
        for module_hash in [*blocks_by_module_hash.keys()]:
            if len(blocks_by_module_hash[module_hash]) == 1:
                del blocks_by_module_hash[module_hash]
        return blocks_by_module_hash

    def get_module_tree(self):
        module_tree = {"children": {}, "name": None}
        for module_name in self.module_to_layer_range_map.keys():
            full_name = module_name.split('.')
            current_tree = module_tree["children"]
            for depth, name in enumerate(full_name):
                if name not in current_tree:
                    current_tree[name] = {"children": {}, "name": None}
                if depth == len(full_name) - 1:
                    current_tree[name]["name"] = module_name
                else:
                    current_tree = current_tree[name]["children"]
        return module_tree

    def get_blocks_by_edge_hash(self, blocks_by_module_hash):
        blocks_by_edge_hash = {}
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                block.collect_edges()
                edge_hash = block.edge_hash
                if edge_hash not in blocks_by_edge_hash:
                    blocks_by_edge_hash[edge_hash] = []
                blocks_by_edge_hash[edge_hash].append(block)
        for edge_hash in [*blocks_by_edge_hash.keys()]:
            if len(blocks_by_edge_hash[edge_hash]) == 1:
                del blocks_by_edge_hash[edge_hash]
            else:
                block_list = blocks_by_edge_hash[edge_hash]
                blocks_by_edge_hash[edge_hash] = sorted(
                    block_list, key=lambda x: x.layer_range.start)
        for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
            for block in block_list:
                block.type_id = type_id
        return blocks_by_edge_hash

    def get_layer_to_block(self):
        layer_to_block = {}
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                for layer_index in block.layer_range:
                    layer_to_block[layer_index] = block
        return layer_to_block

    def clean_blocks(self):
        for block in self.blocks:
            block.p2p_type = None
            block.is_superset = False
            block.is_subset = False

    def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
                      graph_config: GraphConfig):
        if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
            return
        assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
        block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)

        for block in self.backbone_blocks:
            block.p2p_type = None
        for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
            next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
            last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
            next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
            )[0]
            num_devices_per_host = phy_mesh.num_devices_per_host
            next_block = self.backbone_blocks[(stage_index + 1) *
                                              block_per_stage]
            if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
                next_block.p2p_type = P2PType.CROSS_HOST
                graph_config.has_cross_host = True
            else:
                next_block.p2p_type = P2PType.CROSS_DEVICE
                graph_config.has_cross_device = True

    def get_graph_mapping(self):
        layer_mapping = {}
        block_mapping = {}
        p2p_types = {}
        p2p_tensors = {}
        for block_list in self.blocks_by_edge_hash.values():
            superset_blocks = []
            superset_block_index = {}
            for block in block_list:
                block_added = False
                for index, superset_block in enumerate(list(superset_blocks)):
                    if block.p2p_type == superset_block.p2p_type:
                        if block.relative_inter_edges.issubset(
                                superset_block.relative_inter_edges):
                            block.is_subset = True
                            block.is_superset = False
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                        elif superset_block.relative_inter_edges.issubset(
                                block.relative_inter_edges):
                            superset_block.is_subset = True
                            superset_block.is_superset = False
                            block.is_subset = False
                            block.is_superset = True
                            superset_blocks[index] = block
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                if not block_added:
                    block.is_subset = False
                    block.is_superset = True
                    superset_blocks.append(block)
                    superset_block_index[id(block)] = len(superset_blocks) - 1
            for block in block_list:
                assert not (block.is_subset and block.is_superset)
                if block.is_subset:
                    superset_block = superset_blocks[superset_block_index[id(
                        block)]]
                    block_mapping[block.block_id] = superset_block.block_id
                    owned_inputs = map(
                        lambda x: x[0],
                        sorted(block.owned_inputs.items(), key=lambda x: x[1]))
                    superset_owned_inputs = map(
                        lambda x: x[0],
                        sorted(superset_block.owned_inputs.items(),
                               key=lambda x: x[1]))
                    for from_input_id, to_input_id in zip(
                            owned_inputs, superset_owned_inputs):
                        from_input_name = self.network.get_input(
                            from_input_id).name
                        to_input_name = self.network.get_input(to_input_id).name
                        layer_mapping[from_input_name] = to_input_name
                    for from_layer_id, to_layer_id in zip(
                            block.layer_range, superset_block.layer_range):
                        from_layer = self.network.get_layer(from_layer_id)
                        to_layer = self.network.get_layer(to_layer_id)
                        layer_mapping[from_layer.name] = to_layer.name
                        for i in range(from_layer.num_outputs):
                            from_output = from_layer.get_output(i)
                            if from_output.is_network_output:
                                to_output = to_layer.get_output(i)
                                layer_mapping[from_output.name] = to_output.name
                    if block.p2p_type is not None:
                        p2p_types[block.block_id] = block.p2p_type
                        p2p_tensors[block.block_id] = [
                            *set(block.get_input_names())
                        ]
                        for from_name, to_name in zip(
                                block.get_input_names(),
                                superset_block.get_input_names()):
                            layer_mapping[
                                f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
        stage_id = 0
        block_to_stage = {}
        for block in self.blocks:
            if block.p2p_type is not None:
                stage_id += 1
            block_to_stage[block.block_id] = stage_id
        return GraphMapping(
            layer_mapping,
            block_mapping,
            p2p_types,
            p2p_tensors,
            block_to_stage,
        )

    def create_simplified_graph(self, graph_config: GraphConfig):
        new_graph = PipelineGraph.create_graph()
        new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
        layer_mapping = graph_config.graph_mapping.layer_mapping

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            if trt_input.name not in layer_mapping:
                new_graph.add_input(trt_input)

        last_blocks = {}
        same_spec_mapping = {}
        same_spec_layer_mapping = {}
        shape_mapping = {}
        building_block_id = 0
        same_spec_ids = {}
        same_spec_count = 0
        for block in self.blocks:
            if not block.is_subset:
                stage_type = None
                if not block.is_superset:
                    if block.block_id == 0:
                        stage_type = StageType.START
                    elif block.block_id == len(self.blocks) - 1:
                        stage_type = StageType.END
                input_mapping = block.get_input_mapping(last_blocks)
                for from_name, to_name in [*input_mapping.items()]:
                    if to_name in same_spec_mapping:
                        input_mapping[from_name] = same_spec_mapping[to_name]
                    if to_name in layer_mapping:
                        input_mapping[from_name] = layer_mapping[to_name]
                if block.is_superset and block.p2p_type is not None:
                    for from_name, to_name in [*input_mapping.items()]:
                        output_tensor = new_graph.get_tensor(to_name)
                        p2p_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
                        p2p_layer.metadata = p2p_layer.name
                        p2p_tensor = p2p_layer.get_output(0)
                        p2p_tensor.name = f"{p2p_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(p2p_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["p2p_type"] = block.p2p_type
                        input_mapping[from_name] = p2p_tensor.name
                        shape_mapping[p2p_tensor.name] = from_name
                    building_block_id += 1
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    wrapped_layer = new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )
                    wrapped_layer.attrs["building_block_id"] = building_block_id
                    wrapped_layer.attrs["stage_type"] = stage_type
                if block.is_superset:
                    last_blocks[block.type_id] = block

                    if block.type_id in same_spec_ids:
                        same_spec_id = same_spec_ids[block.type_id]
                        update_same_spec_count = False
                    else:
                        same_spec_id = same_spec_count
                        same_spec_ids[block.type_id] = same_spec_id
                        update_same_spec_count = True
                    count = same_spec_id
                    for i, (layer_offset,
                            output_index) in enumerate(block.outputs):
                        layer = self.network.get_layer(block.layer_range.start +
                                                       layer_offset)
                        tensor_name = layer.get_output(output_index).name
                        output_tensor = new_graph.get_tensor(tensor_name)
                        same_spec_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        same_spec_layer.name = f"{tensor_name}_same_spec"
                        same_spec_layer.metadata = same_spec_layer.name
                        same_spec_tensor = same_spec_layer.get_output(0)
                        same_spec_tensor.name = f"{same_spec_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(
                            same_spec_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["same_spec_id"] = count
                        count += 1
                        same_spec_mapping[tensor_name] = same_spec_tensor.name
                        same_spec_layer_mapping[
                            same_spec_layer.name] = layer.name
                        shape_mapping[same_spec_tensor.name] = tensor_name
                    for i, graph_input_index in enumerate(
                            block.owned_inputs.keys()):
                        input_name = self.network.get_input(
                            graph_input_index).name
                        input_tensor = new_graph.get_input(input_name)
                        input_tensor.attrs["same_spec_id"] = count
                        count += 1
                    if update_same_spec_count:
                        same_spec_count = count
                building_block_id += 1
        graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping

        if len(self.backbone_blocks) >= 2:
            start_block = self.backbone_blocks[0]
            if start_block.is_subset:
                start_block = self.blocks[graph_config.graph_mapping.
                                          block_mapping[start_block.block_id]]
            for i in start_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_start_block"] = True
            end_block = self.backbone_blocks[-1]
            if end_block.is_subset:
                end_block = self.blocks[graph_config.graph_mapping.
                                        block_mapping[end_block.block_id]]
            for i in end_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_end_block"] = True
            slowest_p2p_type = None
            if graph_config.has_cross_host:
                slowest_p2p_type = P2PType.CROSS_HOST
            elif graph_config.has_cross_device:
                slowest_p2p_type = P2PType.CROSS_DEVICE
            if slowest_p2p_type is not None:
                for block in self.blocks:
                    if block.is_superset and block.p2p_type == slowest_p2p_type:
                        for i in block.layer_range:
                            layer_name = self.network.get_layer(i).name
                            layer = new_graph.get_layer(layer_name)
                            layer.attrs["in_slowest_block"] = True

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
                    output.producer.index].is_subset:
                continue
            if trt_output.is_shape_tensor:
                new_output = new_graph.add_output_shape(trt_output)
            else:
                new_output = new_graph.add_output(trt_output)
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, new_output.name):
                    sharded_io = True
                    break
            if not sharded_io:
                new_output.producer.attrs["is_replicated"] = True

        for input in new_graph.inputs:
            input_name = input.name
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, input_name):
                    sharded_io = True
                    break
            if not sharded_io:
                input.attrs["is_replicated"] = True
            for pattern, repl in self.same_spec_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = new_graph.get_output(output_name)
                    if output is not None:
                        if "same_spec_id" in input.attrs:
                            same_spec_id = input.attrs["same_spec_id"]
                        else:
                            same_spec_id = same_spec_count
                            same_spec_count += 1
                            input.attrs["same_spec_id"] = same_spec_id
                        output.attrs["same_spec_id"] = same_spec_id
                        if math.prod(self.graph.get_input(
                                input_name).shape) < math.prod(
                                    self.graph.get_output(output_name).shape):
                            input.attrs["no_memory_footprint"] = True
                        else:
                            output.attrs["no_memory_footprint"] = True

        return new_graph, shape_mapping

    def enrich_shape_info(self, shape_mapping):
        shapes = self.shape_info.shapes.copy()
        max_shapes = self.shape_info.max_shapes.copy()
        values = self.shape_info.values.copy()
        shape_layers = self.shape_info.shape_layers
        for from_name, to_name in shape_mapping.items():
            if to_name in shapes:
                shapes[from_name] = shapes[to_name]
            if to_name in max_shapes:
                max_shapes[from_name] = max_shapes[to_name]
            if to_name in values:
                values[from_name] = values[to_name]
        shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
        return shape_info

    def simplify_graph(
            self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
            num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
        num_blocks = len(self.backbone_blocks)
        if num_blocks % num_stages != 0:
            return None, None
        graph_config = GraphConfig()
        graph_config.num_micro_batches = self.num_micro_batches
        graph_config.num_blocks = num_blocks
        graph_config.num_stages = num_stages
        graph_config.phy_mesh = phy_mesh
        stage_phy_meshes = phy_mesh.split_pipeline_meshes(
            num_stages, num_devices_per_stage)
        graph_config.stage_phy_meshes = stage_phy_meshes
        with silent_trt_logger():
            self.clean_blocks()
            self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
            graph_config.graph_mapping = self.get_graph_mapping()
            new_graph, shape_mapping = self.create_simplified_graph(
                graph_config)
            shape_info = self.enrich_shape_info(shape_mapping)
            new_graph.assign_shapes(shape_info)
            return new_graph, graph_config

    def get_graph_mapping_for_shape(self):
        layer_mapping = {}
        tensor_mapping = {}
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            for block in block_list[1:]:
                for from_layer_id, to_layer_id in zip(block.layer_range,
                                                      head_block.layer_range):
                    from_layer = self.network.get_layer(from_layer_id)
                    to_layer = self.network.get_layer(to_layer_id)
                    layer_mapping[from_layer.name] = to_layer.name
                    for i in range(from_layer.num_outputs):
                        tensor_mapping[from_layer.get_output(
                            i).name] = to_layer.get_output(i).name
        return layer_mapping, tensor_mapping

    def create_simplified_graph_for_shape(self):
        new_graph = PipelineGraph.create_graph()

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            new_graph.add_input(trt_input)

        head_blocks = {}
        removed_blocks = set()
        removed_layers = set()
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            head_blocks[head_block.type_id] = head_block
            for block in block_list[1:]:
                removed_blocks.add(id(block))
                for layer_index in block.layer_range:
                    removed_layers.add(layer_index)

        for block in self.blocks:
            if not id(block) in removed_blocks:
                input_mapping = block.get_input_mapping(head_blocks)
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in removed_layers:
                continue
            if trt_output.is_shape_tensor:
                new_graph.add_output_shape(trt_output)
            else:
                new_graph.add_output(trt_output)

        return new_graph

    def get_full_shape_info(self, num_micro_batches):
        layer_mapping, tensor_mapping = self.graph_mapping_for_shape
        optimization_profiles = self.llm_network._generate_optimization_profiles(
        )
        if len(optimization_profiles) > 0:
            optimization_profile = optimization_profiles[-1]
        else:
            optimization_profile = None
        shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                    optimization_profile)
        max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                        optimization_profile,
                                        shape_type=ShapeType.MAX)
        shape_info.max_shapes = max_shape_info.shapes
        for removed_tensor_name, tensor_name in tensor_mapping.items():
            shape_info.shapes[removed_tensor_name] = shape_info.shapes[
                tensor_name]
            shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
                tensor_name]
            if tensor_name in shape_info.values:
                shape_info.values[removed_tensor_name] = shape_info.values[
                    tensor_name]
        for removed_layer_name, layer_name in layer_mapping.items():
            if layer_name in shape_info.shape_layers:
                shape_info.shape_layers.add(removed_layer_name)
        return shape_info

    def init_layer_hash(self):
        with silent_trt_logger():
            optimization_profiles = self.llm_network._generate_optimization_profiles(
            )
            if len(optimization_profiles) > 0:
                optimization_profile = optimization_profiles[-1]
            else:
                optimization_profile = None
            shape_info = get_shape_info(self.network, optimization_profile)
        dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
        for layer in self.graph.layers:
            layer_hash = get_cache_key(
                layer.as_trt(),
                shape_info.shapes,
                shape_info.values,
                dtypes,
            )
            layer.attrs["hash"] = layer_hash
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/solver.py
ADDED
@@ -0,0 +1,641 @@
+"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes.
+"""
+import multiprocessing
+import time
+import warnings
+from collections import defaultdict
+
+import numpy as np
+import pulp
+from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum
+
+from ..logger import logger
+
+
+class Solution:
+
+    def __init__(self, leaf_strategies, s_val, e_val, edge_pairs,
+                 node_index_dict, total_cost):
+        self.leaf_strategies = leaf_strategies
+        self.nodes = [
+            strategies_vector.node for strategies_vector in self.leaf_strategies
+        ]
+        self.s_val = s_val
+        self.e_val = e_val
+        self.total_cost = total_cost
+        self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2)))
+        self.node_index_dict = node_index_dict
+        self.index_node_dict = {}
+        for node, index in self.node_index_dict.items():
+            self.index_node_dict[index] = node
+        self.node_best_strategy = {}
+        self._annotate_strategy()
+
+    def _annotate_strategy(self):
+        self.node_best_strategy = {}
+        for index, node in enumerate(self.nodes):
+            best_strategy_id = self.s_val[index]
+            best_strategy = self.leaf_strategies[index][best_strategy_id]
+            self.node_best_strategy[node.node_name] = best_strategy
+
+        for edge_idx, edge_pair in enumerate(self.edge_pairs):
+            src_node = self.index_node_dict[edge_pair[0]]
+            dst_node = self.index_node_dict[edge_pair[1]]
+            src_node_index = self.node_index_dict[src_node]
+            for dst_pre_node in dst_node.predecessor_nodes:
+                if dst_pre_node is None:
+                    continue
+                if src_node.node_name == dst_pre_node.node_name:
+                    self.node_best_strategy[
+                        dst_node.node_name].best_resharding_cost[
+                            src_node.node_name] = [
+                                self.node_best_strategy[dst_node.node_name].
+                                resharding_costs[src_node.node_name][
+                                    self.s_val[src_node_index]]
+                            ]
+
+    def print_solution(self):
+        for index, node in enumerate(self.nodes):
+            best_strategy = self.node_best_strategy[node.node_name]
+            print(f'\n[{index}]: node_name = {node.node_name}')
+            best_strategy.print_strategy(best_resharding_cost_only=True)
+        print(f'solution total cost = {self.total_cost}')
+
+
+class CostGraph:
+    '''
+    A graph data structure to simplify the edge cost graph. It has two main functions:
+    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
+    CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
+    2. To reduce the search space, we merge computationally trivial operators, such as
+    element-wise operators, transpose, and reduction, into their following nodes. The merging information will
+    be given by the StrategiesVector depending on the type of the target node and its following nodes.
+
+    Argument:
+        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
+        simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
+    '''
+
+    def __init__(self, leaf_strategies):
+        self.leaf_strategies = leaf_strategies
+        self.nodes = [
+            strategies_vector.node for strategies_vector in leaf_strategies
+        ]
+        # stores number of strategies in each node
+        self.node_strategies_vector = {}
+        for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
+            self.node_strategies_vector[node] = strategies_vector
+        # extra_node_costs will store the extra costs introduced by merging nodes
+        self.extra_node_costs = {}
+        self.following_dict = {}
+        self._build_cost_graph()
+
+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
+    def _build_cost_graph(self):
+        '''
+        This method will generate edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children'
+        attributes will be set on each node.
+        '''
+        self.edge_costs = {}
+        for dst_node, strategies_vector in zip(self.nodes,
+                                               self.leaf_strategies):
+            # build edge_cost
+            for src_node in dst_node.predecessor_nodes:
+                if src_node is None:
+                    continue
+                if src_node not in self.nodes:
+                    continue
+                node_pair = (src_node, dst_node)
+                edge_cost = {}
+                for i in range(len(strategies_vector)):
+                    for j in range(len(self.node_strategies_vector[src_node])):
+                        resharding_cost = strategies_vector[i].resharding_costs[
+                            src_node.node_name][j][-1]
+                        edge_cost[(j, i)] = resharding_cost
+                self.edge_costs[node_pair] = edge_cost
+
+    def get_edge_cost(self, src_node, dst_node):
+        return self.edge_costs[(src_node, dst_node)]
+
+
+class Solver:
+    INFINITY_COST = 1e13
+
+    def __init__(self,
+                 cost_graph: CostGraph,
+                 memory_budget: float = -1.0,
+                 solution_numbers: int = 1,
+                 memory_increasing_coefficient: float = 1.3,
+                 verbose=False):
+        '''
+        The Solver class integrates information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.
+        Argument:
+            graph: The computing graph to be optimized.
+            strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
+            cost_graph: A graph data structure to simplify the edge cost graph.
+            graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+            memory_budget: Memory constraint for the solution.
+            solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
+            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budgets.
+        '''
+        self.cost_graph = cost_graph
+        self.leaf_strategies = cost_graph.leaf_strategies
+        self.nodes = cost_graph.nodes
+        self.memory_budget = memory_budget
+        self.solution_numbers = solution_numbers
+        if self.solution_numbers > 1:
+            self.memory_increasing_coefficient = memory_increasing_coefficient
+        else:
+            self.memory_increasing_coefficient = 1
+        # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+        # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+        # self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.liveness_list = self.nodes
+        self.node_index_dict = self._generate_node_index_dict()
+        # The last solution vector of auto sharding.
+        self.last_s_val = None
+        # The last objective value of the best ILP solution.
+        self.last_objective = None
+        self.verbose = verbose
+
+    def _generate_node_index_dict(self):
+        node_index_dict = {}
+        for index, node in enumerate(self.nodes):
+            node_index_dict[node] = index
+        return node_index_dict
+
+    def _prepare_data_for_solver(self):
+        '''
+        Extract information from components for solver.
+        '''
+        node_nums = len(self.leaf_strategies)
+        memory_budget = self.memory_budget
+
+        # prepare strategies_len
+        strategies_len = []
+        for node in self.nodes:
+            strategies_len.append(
+                len(self.cost_graph.node_strategies_vector[node]))
+        strategies_len = np.array(strategies_len)
+
+        # prepare edge_pairs and resharding costs
+        edge_pairs = []
+        resharding_costs = []
+        edge_cost_level = []
+        edge_resharding_weights = []
+        for pairs, edge_cost in self.cost_graph.edge_costs.items():
+            src_node = pairs[0]
+            dst_node = pairs[1]
+            src_node_index = self.node_index_dict[src_node]
+            dst_node_index = self.node_index_dict[dst_node]
+            edge_pairs.append(src_node_index)
+            edge_pairs.append(dst_node_index)
+            edge_cost_level.append(
+                (dst_node.building_block_id, dst_node.cost_level))
+            for i in range(strategies_len[src_node_index]):
+                for j in range(strategies_len[dst_node_index]):
+                    resharding_costs.append(edge_cost[(i, j)])
+            edge_resharding_weights.append(dst_node.resharding_weight +
+                                           dst_node.pipeline_weight)
+        edge_pairs = np.array(edge_pairs)
+        resharding_costs = np.array(resharding_costs)
+        edge_resharding_weights = np.array(edge_resharding_weights)
+        # prepare compute_costs, communication_costs and memory_costs
+        compute_costs = []
+        communication_costs = []
+        memory_costs = []
+        peak_act_memory_costs, constant_memory_costs = [], []
+        node_sharding_weights = []
+        for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
+            for index, strategy in enumerate(strategies_vector):
+                compute_cost = strategy.sharding_cost
+                origin_communication_cost = strategy.communication_cost
+                memory_cost = strategy.const_memory_footprint * node.sharding_weight
+                peak_act_memory = strategy.peak_memory_footprint
+                # extract the memory cost in float from MemoryCost item and sum them up
+                compute_costs.append(compute_cost)
+                # node in extra_node_costs means it has some extra communication
+                # cost from node merging, so we need to add those extra communication
+                # cost into
+
+                communication_costs.append(origin_communication_cost)
+                peak_act_memory_costs.append(peak_act_memory)
+                constant_memory_costs.append(memory_cost)
+                node_sharding_weights.append(node.sharding_weight +
+                                             node.pipeline_weight)
+
+        compute_costs = np.array(compute_costs)
+        communication_costs = np.array(communication_costs)
+        memory_costs = np.array([constant_memory_costs, peak_act_memory_costs])
+        node_sharding_weights = np.array(node_sharding_weights)
+        same_spec_nodes_dict = defaultdict(list)
+        node_cost_level = []
+        for idx, node in enumerate(self.nodes):
+            if node.same_spec_id >= 0:
+                same_spec_nodes_dict[node.same_spec_id].append(idx)
+            node_cost_level.append((node.building_block_id, node.cost_level))
+        # omit initial value for nodes
+        s_init_np = None
+        following_nodes = [-1 for i in range(node_nums)]
+        liveness_set = self.nodes
+        alias_set = []
+        alias_convert_costs = None
+        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, edge_cost_level, alias_convert_costs, s_init_np, self.verbose
+
+    def _call_solver_serialized_args(self,
+                                     node_nums,
+                                     memory_budget,
+                                     strategies_len,
+                                     following_nodes,
+                                     edge_pairs,
+                                     alias_set,
+                                     liveness_set,
+                                     compute_costs,
+                                     communication_costs,
+                                     memory_costs,
+                                     resharding_costs,
+                                     node_sharding_weights,
+                                     edge_resharding_weights,
+                                     same_spec_nodes_dict,
+                                     node_cost_level,
+                                     edge_cost_level,
+                                     alias_convert_costs,
+                                     s_init_np=None,
+                                     verbose=True):
+        """
+        Call the solver with serialized arguments.
+        """
+
+        time.time()
+
+        for x in [
+                strategies_len, edge_pairs, compute_costs, communication_costs,
+                memory_costs, resharding_costs, node_sharding_weights,
+                edge_resharding_weights
+        ]:
+            assert isinstance(x, np.ndarray)
+        assert len(strategies_len) == node_nums, "strategies_len"
+
+        def get_non_zero_index(binary_vector):
+            """
+            Get the index of non-zero item in a vector.
+            """
+            ct = 0
+            ret = None
+            for i, elem in enumerate(binary_vector):
+                if pulp.value(elem):
+                    ret = i
+                    ct += 1
+
+            assert ct == 1
+            return ret
+
+        # 0. Unpack flatten numpy arrays
+        s_follow = following_nodes
+        s_alias = alias_set
+
+        E = edge_pairs.reshape((-1, 2))  # noqa
+        r = []
+        pt = 0
+        edge_set = set()
+        for (i, j) in E:
+            prod_length = strategies_len[i] * strategies_len[j]
+
+            if (i, j) in edge_set:
+                raise ValueError(f"Duplicated edges: {(i, j)}")
+
+            edge_set.add((i, j))
+            r.append(resharding_costs[pt:pt + prod_length])
+            pt += prod_length
+        assert pt == len(resharding_costs)
+
+        ######################
+        # omit alias set now #
+        ######################
+
+        # A = alias_set.reshape((-1, 2))  # noqa
+        # for (i, j) in A:
+        #     prod_length = strategies_len[i] * strategies_len[j]
+        #     v.append(alias_convert_costs[pt:pt + prod_length])
+        #     pt += prod_length
+        # assert pt == len(alias_convert_costs)
+
+        # L = []  # noqa
+        # pt = node_nums
+        # for i in range(node_nums):
+        #     length = liveness_set[i]
+        #     L.append(liveness_set[pt:pt + length])
+        #     pt += length
+        # assert pt == len(liveness_set)
+        pt = 0
+
+        c = []
+        d = []
+        m = []
+        peak_m = []
+        pt = 0
+        for i in range(node_nums):
+            length = strategies_len[i]
+            c.append(compute_costs[pt:pt + length])
+            d.append(communication_costs[pt:pt + length])
+            m.append(memory_costs[0][pt:pt + length])
+            peak_m.append(memory_costs[1][pt:pt + length])
+            pt += length
+        assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
+        assert pt == len(
+            communication_costs), f"{pt} == {len(communication_costs)}"
+        assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}"
+
+        # 1. Create variables
+
+        #############################
+        # create variables for node #
+        #############################
+        s = []
+        num_nodes = 0
+        reverse_follow_backpatch = []
+        for i in range(node_nums):
+            if s_follow[i] < 0:
+                if strategies_len[i] == 1:
+                    s.append([1])
+                else:
+                    if i not in s_alias:
+                        num_nodes += 1
+                        s.append(
+                            LpVariable.matrix(f"s[{i}]",
+                                              (range(strategies_len[i]), ),
+                                              cat="Binary"))
+                    else:
+                        s.append(s[s_alias[i]])
+            else:
+                if s_follow[i] < len(s):
+                    s.append(s[s_follow[i]])
+                else:
+                    s.append(None)
+                    reverse_follow_backpatch.append(i)
+
+        for i in reverse_follow_backpatch:
+            s[i] = s[s_follow[i]]
+
+        #############################
+        # create variables for edge #
+        #############################
+        e = []
+        num_edges = 0
+        map_edge_to_idx = {}
+        for (idx, (i, j)) in enumerate(E):
+            if len(s[i]) == 1:
+                e.append(s[j])
+            elif len(s[j]) == 1:
+                e.append(s[i])
+            else:
+                if i in s_alias and j in s_alias and (
+                        s_alias[i], s_alias[j]) in map_edge_to_idx:
+                    e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
+                else:
+                    num_edges += 1
+                    e.append(
+                        LpVariable.matrix(f"e[{i},{j}]",
+                                          (range(len(s[i]) * len(s[j])), ),
+                                          cat="Binary"))
+            assert len(e[idx]) == len(r[idx])
+            map_edge_to_idx[(i, j)] = idx
+        for element in s:
+            assert len(element) > 0
+        # 2. Set initial value
+        ######################################
+        # set a initial value for warm start #
+        ######################################
+        if s_init_np is not None:
+            s_init = s_init_np.reshape((-1, 3))
+            for (idx, value, fix) in s_init:
+                for i in range(len(s[idx])):
+                    s[idx][i].setInitialValue(i == value)
+                    if fix:
+                        s[idx][i].fixValue()
+
+        # 3. Objective
+        prob = LpProblem("myProblem", LpMinimize)
+        ###################################################################
+        # computing the node cost(computing cost and communication cost) #
+        ###################################################################
+        obj = 0
+        block_cost_level_dict = {}
+        for i in range(node_nums):
+            assert len(s[i]) == len(c[i])
+            assert len(s[i]) == len(d[i])
+            obj += (lpDot(s[i], c[i]) +
+                    lpDot(s[i], d[i])) * node_sharding_weights[i]
+            cost_level = node_cost_level[i]
+            if -1 != cost_level[1]:
+                if cost_level in block_cost_level_dict:
+                    block_cost_level_dict[cost_level] += lpDot(
+                        s[i], c[i]) + lpDot(s[i], d[i])
+                else:
+                    block_cost_level_dict[cost_level] = lpDot(
+                        s[i], c[i]) + lpDot(s[i], d[i])
+
+        #############################################
+        # computing the edge cost(resharding cost)  #
+        #############################################
+
+        for i in range(len(E)):
+            assert len(e[i]) == len(r[i])
+            obj += lpDot(e[i], r[i]) * edge_resharding_weights[i]
+            cost_level = edge_cost_level[i]
+            if -1 != cost_level[1]:
+                if cost_level in block_cost_level_dict:
+                    block_cost_level_dict[cost_level] += lpDot(e[i], r[i])
+                else:
+                    block_cost_level_dict[cost_level] = lpDot(e[i], r[i])
+        prob += obj
+        if len(block_cost_level_dict) >= 2:
+            block_cost_levels = [key for key in block_cost_level_dict.keys()]
+            for i in range(len(block_cost_levels)):
+                for j in range(i + 1, len(block_cost_levels)):
+                    if block_cost_levels[i][1] > block_cost_levels[j][1]:
+                        prob += block_cost_level_dict[
+                            block_cost_levels[i]] >= block_cost_level_dict[
+                                block_cost_levels[j]] + 1e-6
+                    elif block_cost_levels[i][1] < block_cost_levels[j][1]:
+                        prob += block_cost_level_dict[
+                            block_cost_levels[j]] >= block_cost_level_dict[
+                                block_cost_levels[i]] + 1e-6
+        # 4. Constraints
+        # (a). specified by `cat="Binary"`
+
+        # (b)
+        #################################################
+        # make sure each node only choose one strategy  #
+        #################################################
+        for i in range(node_nums):
+            if s_follow[i] < 0:
+                prob += lpSum(s[i]) == 1
+
+        # (c)
+        #################################################
+        # force to constrain some nodes have the same sharding specs #
+        #################################################
+        for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items():
+            num_same_spec_nodes = len(same_spec_nodes_id)
+            if num_same_spec_nodes >= 2:
+                src_node_s = s[same_spec_nodes_id[0]]
+                num_specs = len(src_node_s)
+                for i in range(1, num_same_spec_nodes):
+                    dst_node_s = s[same_spec_nodes_id[i]]
+                    assert len(
+                        dst_node_s
+                    ) == num_specs, f'unmatched num_specs when force node {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} the same specs'
+                    for j in range(num_specs):
+                        prob += (src_node_s[j] == dst_node_s[j])
+
+        # (c)
+        #################################################
+        # compute memory consumption with liveness set  #
+        #################################################
+        if memory_budget > 0:
+            # calculate the constant memory
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                mem += lpSum(s[node_index][j] * m[node_index][j]
+                             for j in range(len(s[node_index])))
+            # calculate the peak activation memory
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j]
+                                     for j in range(len(s[node_index])))
+                total_mem = mem + cur_peak_mem
+                prob += total_mem <= memory_budget
+
+        # (d). specified by `cat="Binary"`
+
+        for (idx, (i, j)) in enumerate(E):
+            if strategies_len[i] == 1 or strategies_len[j] == 1:
+                continue
+
+            # (e)
+            prob += lpSum(e[idx]) == 1
+
+            # (f)
+            for row in range(len(s[i])):
+                C = len(s[j])  # noqa
+                prob += lpSum(e[idx][row * C + col]
+                              for col in range(0, C)) <= s[i][row]
+
+            # (g)
+            for col in range(len(s[j])):
+                R = len(s[i])  # noqa
+                C = len(s[j])  # noqa
+                prob += lpSum(e[idx][row * C + col]
+                              for row in range(0, R)) <= s[j][col]
+
+        if prob.objective.isNumericalConstant():
+            objective = float(pulp.value(prob.objective))
+            status = pulp.LpStatusOptimal
+        else:
+            msg = verbose
+            time_limit = 600
+            solver = pulp.PULP_CBC_CMD(
+                mip=True,
+                msg=msg,
+                timeLimit=time_limit,
+                threads=multiprocessing.cpu_count(),
+            )
+            prob.solve(solver)
+
+            status = prob.status
+            objective = pulp.value(prob.objective)
+            objective = float(
+                objective) if objective is not None else self.INFINITY_COST
+
+            if prob.status in [pulp.LpStatusInfeasible]:
+                objective = self.INFINITY_COST
+
+        # Get and check results
+        s_val = np.full((node_nums, ), -1, dtype=np.int32)
+        for i in range(node_nums):
+            s_val[i] = get_non_zero_index(s[i])
+
+        e_val = np.full((len(E), ), -1, dtype=np.int32)
+        for (idx, (i, j)) in enumerate(E):
+            e_val[idx] = get_non_zero_index(e[idx])
+            i_spec_index = e_val[idx] // len(s[j])
+            j_spec_index = e_val[idx] % len(s[j])
+            assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
+            assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
+            if verbose and r[idx][e_val[idx]] > 0:
+                print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
+
+        self.last_s_val = list(s_val)
+        # self._recover_merged_node_strategy()
+        self.last_objective = objective
+
+        if objective >= self.INFINITY_COST:
+            warnings.warn(
+                f"Cannot find an optimized solution given memory budget {self.memory_budget}, Please consider\n" + \
+                f"1. increase memory budget if possible\n" + \
+                f"2. enlarge mesh shape if possible\n" + \
+                f"3. decrease the maximum parameters(i.e., max_batch_size, max_seq_len, etc.) in building config")
+        if memory_budget > 0:
+            # calculate the constant memory
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                j = self.last_s_val[node_index]
+                mem += m[node_index][j]
+            max_peak_mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                j = self.last_s_val[node_index]
+                cur_peak_mem = peak_m[node_index][j]
+                max_peak_mem = max(max_peak_mem, cur_peak_mem)
+            logger.debug(
+                f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
+            )
+
+        solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
+                            edge_pairs, self.node_index_dict,
+                            self.last_objective)
+        return status, solution
+
+    def find_solution(self):
+        """
+        Call the solver with serialized arguments and handle python errors. Additionally,
+        we could give a series of solutions with different memory budgets.
+        """
+        if self.solution_numbers == 1:
+            args = self._prepare_data_for_solver()
+            ret = self._call_solver_serialized_args(*args)
+
+            return ret
+
+        origin_memory_budget = self.memory_budget
+        memory_budget_list = [
+            origin_memory_budget * self.memory_increasing_coefficient**i
+            for i in range(self.solution_numbers)
+        ]
+        ret_list = []
+        for memory_budget in memory_budget_list:
+            self.memory_budget = memory_budget
+            args = self._prepare_data_for_solver()
+            ret = self._call_solver_serialized_args(*args)
+            ret_list.append(ret)

+        return ret_list
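
The ILP above selects one strategy per node (the binary vector s[i]) and one strategy pair per edge (the binary vector e[idx]), and couples them through constraints (b), (e), (f) and (g). Below is a minimal, self-contained sketch of that pattern on a made-up two-node graph; the costs, sizes and variable names are illustrative assumptions, not values taken from solver.py.

# Toy sketch of the one-hot node/edge constraints used by the solver (illustrative only).
import numpy as np
import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

node_costs = [np.array([1.0, 4.0]), np.array([2.0, 1.0])]   # hypothetical per-strategy node costs
edge_costs = np.array([0.0, 5.0, 5.0, 0.0])                 # hypothetical resharding costs, row-major over (i, j)

prob = LpProblem("toy_auto_parallel", LpMinimize)
s0 = LpVariable.matrix("s0", (range(2), ), cat="Binary")    # strategy choice for node 0
s1 = LpVariable.matrix("s1", (range(2), ), cat="Binary")    # strategy choice for node 1
e01 = LpVariable.matrix("e01", (range(4), ), cat="Binary")  # joint choice for edge (0, 1)

prob += lpDot(s0, node_costs[0]) + lpDot(s1, node_costs[1]) + lpDot(e01, edge_costs)
prob += lpSum(s0) == 1                                      # (b) one strategy per node
prob += lpSum(s1) == 1
prob += lpSum(e01) == 1                                     # (e) one pair per edge
for row in range(2):                                        # (f) edge row must agree with s0
    prob += lpSum(e01[row * 2 + col] for col in range(2)) <= s0[row]
for col in range(2):                                        # (g) edge column must agree with s1
    prob += lpSum(e01[row * 2 + col] for row in range(2)) <= s1[col]

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([pulp.value(v) for v in s0], [pulp.value(v) for v in s1], pulp.value(prob.objective))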
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py
ADDED
File without changes
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py
ADDED
@@ -0,0 +1,41 @@
+import copy
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Activation(Node):
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['input0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            in0_partition_dict = dim_partition_dict
+            out_partition_dict = copy.deepcopy(dim_partition_dict)
+            dim_partition_dict_mapping = {
+                "input0": in0_partition_dict,
+                "output0": out_partition_dict,
+            }
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            name = '{} = <activation op> {}'.format(
+                sharding_spec_mapping['output0'].sharding_sequence,
+                sharding_spec_mapping['input0'].sharding_sequence)
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+        return strategies_vector
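
The candidate enumeration above (the empty partition, all 1D shardings over mesh axis 0, axis 1 and the fused axes [0, 1], plus all 2D shardings over axes 0 and 1) recurs in most of the node classes below. A rough standalone sketch of what those helpers are assumed to produce for a rank-2 tensor follows; these are hypothetical stand-ins, not the actual Node methods.

from itertools import permutations

def enumerate_1d_sharding(mesh_axes, dim_size):
    # shard a single tensor dim across the given mesh axis (or fused axes)
    return [{dim: list(mesh_axes)} for dim in range(dim_size)]

def enumerate_2d_sharding(mesh_axis0, mesh_axis1, dim_size):
    # shard two different tensor dims across two different mesh axes
    return [{d0: list(mesh_axis0), d1: list(mesh_axis1)}
            for d0, d1 in permutations(range(dim_size), 2)]

dim_size = 2
candidates = [{}]                                         # fully replicated
candidates += enumerate_1d_sharding([0], dim_size)        # {0: [0]}, {1: [0]}
candidates += enumerate_1d_sharding([1], dim_size)        # {0: [1]}, {1: [1]}
candidates += enumerate_1d_sharding([0, 1], dim_size)     # {0: [0, 1]}, {1: [0, 1]}
candidates += enumerate_2d_sharding([0], [1], dim_size)   # {0: [0], 1: [1]}, {0: [1], 1: [0]}
print(candidates)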
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py
ADDED
@@ -0,0 +1,34 @@
+import copy
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Assertion(Node):
+
+    def _collect_strategies(self, device_mesh):
+        predecessor = self.predecessor_nodes[0]  # one input for assertion node
+        strategies_vector = StrategiesVector(self)
+        for idx, strategy in enumerate(predecessor.strategies_vector):
+            global_input_name = self.op_data[
+                'input0'].name  # current node's local name input0 -> global name xxx
+            prenode_local_name = predecessor.global_to_local_op_name[
+                global_input_name]  # global name xxx -> pre node local output name
+            dim_partition_dict = copy.deepcopy(
+                strategy.sharding_specs[prenode_local_name].dim_partition_dict)
+            in0_partition_dict = dim_partition_dict
+            dim_partition_dict_mapping = {
+                "input0": in0_partition_dict,
+            }
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                return strategies_vector
+            name = '<assertion> {}'.format(
+                sharding_spec_mapping['input0'].sharding_sequence)
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py
ADDED
@@ -0,0 +1,45 @@
+import copy
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Cast(Node):
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['input0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            in0_partition_dict = dim_partition_dict
+            out_partition_dict = copy.deepcopy(dim_partition_dict)
+            dim_partition_dict_mapping = {
+                "input0": in0_partition_dict,
+                "output0": out_partition_dict,
+            }
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            name = '{} = <cast op> {}'.format(
+                sharding_spec_mapping['output0'].sharding_sequence,
+                sharding_spec_mapping['input0'].sharding_sequence)
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+
+        return strategies_vector
+
+    def _update_memory_cost(self, strategies):
+        pass
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py
ADDED
@@ -0,0 +1,58 @@
+__all__ = [
+    'CommSpec',
+]
+
+
+class CommSpec:
+
+    def __init__(self,
+                 comm_pattern,
+                 sharding_spec,
+                 gather_dim=None,
+                 shard_dim=None,
+                 logical_process_axis=None,
+                 mix_gather=False,
+                 forward_only=True):
+        self.comm_pattern = comm_pattern
+        self.sharding_spec = sharding_spec
+        self.gather_dim = gather_dim
+        self.shard_dim = shard_dim
+        self.logical_process_axis = logical_process_axis
+        self.device_mesh = self.sharding_spec.device_mesh
+        self.mix_gather = mix_gather
+        self.forward_only = forward_only
+        if self.gather_dim:
+            assert len(self.gather_dim) == len(
+                self.logical_process_axis
+            ), f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}'
+        if self.shard_dim:
+            assert len(self.shard_dim) == len(
+                self.logical_process_axis
+            ), f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}'
+        if self.gather_dim and self.shard_dim:
+            assert len(self.shard_dim) == len(
+                self.gather_dim
+            ), f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}'
+
+    def get_comm_cost(self):
+        '''
+        For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
+        compute the communication cost.
+        For shard operation, it is an on-chip operation, so the communication cost is zero.
+        '''
+        comm_size = self.sharding_spec.get_sharded_size_per_device()
+        dtype = self.sharding_spec.dtype
+
+        # reduce list_of_list to list
+        comm_dims = sum(self.logical_process_axis, [])
+        comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern,
+                                                        comm_dims, comm_size,
+                                                        dtype)
+        return comm_cost
+
+    def get_mem_cost(self):
+        return self.device_mesh.shape_consistency_manager.mem_cost([self])
+
+    def get_max_mem_cost(self):
+        return self.device_mesh.shape_consistency_manager.mem_cost(
+            [self], mem_pattern='max')
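
get_comm_cost above defers to DeviceMesh.estimate_comm_cost, which the docstring describes as an alpha-beta model, with sharding itself treated as a free on-chip operation. A toy sketch of such a model follows, with made-up alpha/beta values and volume factors; the real formula lives in device_mesh.py and may differ.

# Illustrative alpha-beta communication cost model (assumed semantics, not the DeviceMesh implementation).
def estimate_comm_cost(comm_pattern, num_devices, message_bytes,
                       alpha=1e-5, beta=1e-10):
    # alpha: per-message latency in seconds, beta: per-byte transfer time; both values are made up
    if comm_pattern == 'split':
        return 0.0                                   # on-chip shard, no traffic
    if comm_pattern == 'all_reduce':
        volume = 2 * (num_devices - 1) / num_devices * message_bytes
    elif comm_pattern in ('all_gather', 'all_to_all', 'reduce_scatter'):
        volume = (num_devices - 1) / num_devices * message_bytes
    else:
        raise ValueError(f'unknown comm pattern {comm_pattern}')
    return alpha + beta * volume

print(estimate_comm_cost('all_reduce', num_devices=8, message_bytes=64 << 20))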
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py
ADDED
@@ -0,0 +1,56 @@
+import copy
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Concatenation(Node):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        layer.to_subclass()
+        batch_dims = [i for i in range(len(self.get_output(0).shape))]
+        self.axis = layer.as_trt().axis
+        batch_dims.remove(self.axis)
+        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
+        layer.to_base_class()
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['output0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+        dim_partition_dict_mapping = {}
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            if self.axis in dim_partition_dict:
+                dim_partition_dict.pop(self.axis)
+            for idx in range(self.num_inputs):
+                in_partition_dict = copy.deepcopy(dim_partition_dict)
+                dim_partition_dict_mapping[f'input{idx}'] = in_partition_dict
+            out_partition_dict = dim_partition_dict
+            dim_partition_dict_mapping['output0'] = out_partition_dict
+
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            name = '{} = <concate along dim {}> {}'.format(
+                sharding_spec_mapping['output0'].sharding_sequence, self.axis, [
+                    sharding_spec_mapping[f'input{idx}'].sharding_sequence
+                    for idx in range(self.num_inputs)
+                ])
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py
ADDED
@@ -0,0 +1,45 @@
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Constant(Node):
+
+    def _update_memory_cost(self, strategies):
+        super()._update_memory_cost(strategies)
+        for strategy in strategies:
+            strategy.inout_memory_footprint = 0.0
+            strategy.peak_memory_footprint = 0.0
+            strategy.const_memory_footprint = strategy.sharding_specs[
+                'output0'].get_max_sharded_size_per_device()
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['output0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            dim_partition_dict_mapping = {'output0': dim_partition_dict}
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
+            sharding_strategy = self._get_sharding_strategy(
+                name=f'constant-op {sharding_seq}',
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+
+        return strategies_vector
+
+    def _profile_sharding_cost(self, strategy, device_mesh):
+        return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py
ADDED
@@ -0,0 +1,49 @@
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class ElementWise(Node):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        batch_dims = [i for i in range(len(self.get_output(0).shape))]
+        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['output0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            in0_partition_dict = self._recover_bcast_partition_dict(
+                dim_partition_dict, self.op_data['input0'])
+            in1_partition_dict = self._recover_bcast_partition_dict(
+                dim_partition_dict, self.op_data['input1'])
+            out_partition_dict = dim_partition_dict
+            dim_partition_dict_mapping = {
+                "input0": in0_partition_dict,
+                "input1": in1_partition_dict,
+                "output0": out_partition_dict,
+            }
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            name = '{} = {} <elementwise> {}'.format(
+                sharding_spec_mapping['output0'].sharding_sequence,
+                sharding_spec_mapping['input0'].sharding_sequence,
+                sharding_spec_mapping['input1'].sharding_sequence)
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py
ADDED
@@ -0,0 +1,59 @@
+import tensorrt as trt
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Fill(Node):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+        layer.to_subclass()
+        self.operation = layer.as_trt().operation
+        layer.to_base_class()
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['output0'].shape)
+        dim_partition_list.append({})
+        if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE:
+            dim_partition_list.extend(
+                self._enumerate_all_possible_1d_sharding([0], dim_size))
+            dim_partition_list.extend(
+                self._enumerate_all_possible_1d_sharding([1], dim_size))
+            dim_partition_list.extend(
+                self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+            dim_partition_list.extend(
+                self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            dim_partition_dict_mapping = {'output0': dim_partition_dict}
+            for i in range(self.num_inputs):
+                dim_partition_dict_mapping[f'input{i}'] = {}
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
+            sharding_strategy = self._get_sharding_strategy(
+                name=f'fill-op {sharding_seq}',
+                sharding_spec_mapping=sharding_spec_mapping,
+                communication_action_mapping={})
+            strategies_vector.append(sharding_strategy)
+
+        return strategies_vector
+
+    def _profile_sharding_cost(self, strategy, device_mesh):
+        updated_layer_attrs = {}
+        updated_input_values = {}
+        shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
+        )
+        if self.layer.num_inputs >= 1:
+            updated_input_values[0] = shape
+        else:
+            updated_layer_attrs['shape'] = shape
+        elapsed_time = self.node_runtime_profiler.runtime_profile(
+            self.layer, updated_layer_attrs, updated_input_values, strategy,
+            device_mesh)
+        return elapsed_time
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
import copy
|
2 |
+
|
3 |
+
import tensorrt as trt
|
4 |
+
|
5 |
+
from .comm_spec import CommSpec
|
6 |
+
from .node import Node
|
7 |
+
from .sharding_spec import DimSpec
|
8 |
+
from .sharding_strategy import StrategiesVector
|
9 |
+
|
10 |
+
|
11 |
+
class Gather(Node):
|
12 |
+
|
13 |
+
def __init__(self, layer):
|
14 |
+
super().__init__(layer)
|
15 |
+
layer.to_subclass()
|
16 |
+
self.mode = layer.as_trt().mode
|
17 |
+
self.axis = layer.as_trt().axis
|
18 |
+
self.num_elementwise_dims = layer.as_trt().num_elementwise_dims
|
19 |
+
self.input_id = 0
|
20 |
+
self.indice_id = 1
|
21 |
+
self.support_vocab_tp = False
|
22 |
+
layer.to_base_class()
|
23 |
+
|
24 |
+
def _update_memory_cost(self, strategies):
|
25 |
+
for strategy in strategies:
|
26 |
+
# for gather node, it input0's read = output0's write
|
27 |
+
inout_memory_footprint = (
|
28 |
+
strategy.sharding_specs['output0'].get_sharded_size_per_device(
|
29 |
+
) * 2 +
|
30 |
+
strategy.sharding_specs['input1'].get_sharded_size_per_device())
|
31 |
+
strategy.inout_memory_footprint = inout_memory_footprint
|
32 |
+
strategy.peak_memory_footprint = (
|
33 |
+
strategy.sharding_specs['output0'].
|
34 |
+
get_max_sharded_size_per_device() + strategy.
|
35 |
+
sharding_specs['input0'].get_max_sharded_size_per_device() +
|
36 |
+
strategy.sharding_specs['input1'].
|
37 |
+
get_max_sharded_size_per_device())
|
38 |
+
|
39 |
+
def _collect_strategies(self, device_mesh):
|
40 |
+
if self.mode == trt.GatherMode.DEFAULT:
|
41 |
+
return self._default_gather_strategies(device_mesh)
|
42 |
+
elif self.mode == trt.GatherMode.ELEMENT:
|
43 |
+
return self._element_gather_strategies(device_mesh)
|
44 |
+
elif self.mode == trt.GatherMode.ND:
|
45 |
+
assert 0, 'unsupport gatherND'
|
46 |
+
else:
|
47 |
+
assert 0, f'unsupport gather mode {self.mode}'
|
48 |
+
|
49 |
+
def _element_gather_strategies(self, device_mesh):
|
50 |
+
dim_partition_list = []
|
51 |
+
dim_size = len(self.op_data['output0'].shape)
|
52 |
+
dim_partition_list.append({})
|
53 |
+
dim_partition_list.extend(
|
54 |
+
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
55 |
+
dim_partition_list.extend(
|
56 |
+
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
57 |
+
dim_partition_list.extend(
|
58 |
+
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
59 |
+
dim_partition_list.extend(
|
60 |
+
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
61 |
+
|
62 |
+
strategies_vector = StrategiesVector(self)
|
63 |
+
for dim_partition_dict in dim_partition_list:
|
64 |
+
if self.axis in dim_partition_dict:
|
65 |
+
dim_partition_dict.pop(self.axis)
|
66 |
+
|
67 |
+
dim_partition_dict_mapping = {
|
68 |
+
'input0': dim_partition_dict,
|
69 |
+
'input1': copy.deepcopy(dim_partition_dict),
|
70 |
+
'output0': copy.deepcopy(dim_partition_dict),
|
71 |
+
}
|
72 |
+
|
73 |
+
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
74 |
+
dim_partition_dict_mapping, device_mesh)
|
75 |
+
if 0 == len(sharding_spec_mapping):
|
76 |
+
continue
|
77 |
+
name = '{} = {} <element gather op axis {}> {}'.format(
|
78 |
+
sharding_spec_mapping['output0'].sharding_sequence,
|
79 |
+
sharding_spec_mapping['input0'].sharding_sequence, self.axis,
|
80 |
+
sharding_spec_mapping['input1'].sharding_sequence)
|
81 |
+
sharding_strategy = self._get_sharding_strategy(
|
82 |
+
name=name,
|
83 |
+
sharding_spec_mapping=sharding_spec_mapping,
|
84 |
+
communication_action_mapping={})
|
85 |
+
strategies_vector.append(sharding_strategy)
|
86 |
+
|
87 |
+
return strategies_vector
|
88 |
+
|
89 |
+
# for plugin, indice is input0, and weight is input1, which is different from gather node
|
90 |
+
def _default_gather_strategies(self, device_mesh):
|
91 |
+
|
92 |
+
def add_sharding_strategy(dim_partition_dict_mapping,
|
93 |
+
                                  vocab_tp_dim=None):
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if len(sharding_spec_mapping) > 0:
                name = '{} = {} <default gather op axis {}, num_elementwise_dims {}> {}'.format(
                    sharding_spec_mapping['output0'].sharding_sequence,
                    sharding_spec_mapping['input0'].sharding_sequence,
                    self.axis, self.num_elementwise_dims,
                    sharding_spec_mapping['input1'].sharding_sequence)
                communication_action_mapping = {}
                if vocab_tp_dim is not None:
                    name += f'_allreduce{DimSpec(vocab_tp_dim)}'
                    output0_comm_action = CommSpec(
                        comm_pattern='all_reduce',
                        sharding_spec=sharding_spec_mapping['output0'],
                        logical_process_axis=[vocab_tp_dim],
                    )
                    communication_action_mapping[
                        'output0'] = output0_comm_action
                sharding_strategy = self._get_sharding_strategy(
                    name=name,
                    sharding_spec_mapping=sharding_spec_mapping,
                    communication_action_mapping=communication_action_mapping)
                strategies_vector.append(sharding_strategy)

        input_id, indice_id = self.input_id, self.indice_id
        strategies_vector = StrategiesVector(self)
        input_size = len(self.op_data[f'input{input_id}'].shape)
        indice_size = len(self.op_data[f'input{indice_id}'].shape)
        output_dim = input_size + indice_size - 1 - self.num_elementwise_dims
        for strategy in self.predecessor_nodes[input_id].strategies_vector:
            # current node's local name input0 -> global name xxx
            global_input_name = self.op_data[f'input{input_id}'].name
            # global name xxx -> pre node local output name
            prenode_local_name = self.predecessor_nodes[
                input_id].global_to_local_op_name[global_input_name]
            input_dim_partition_dict = copy.deepcopy(
                strategy.sharding_specs[prenode_local_name].dim_partition_dict)

            vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None)

            input_mesh_dims = []
            for dim, mesh_dims in input_dim_partition_dict.items():
                input_mesh_dims += mesh_dims
            input_mesh_dims = set(input_mesh_dims)

            for idx_strategy in self.predecessor_nodes[
                    indice_id].strategies_vector:
                # current node's local name input0 -> global name xxx
                global_indice_name = self.op_data[f'input{indice_id}'].name
                # global name xxx -> pre node local output name
                prenode_local_name = self.predecessor_nodes[
                    indice_id].global_to_local_op_name[global_indice_name]
                indice_dim_partition_dict = copy.deepcopy(
                    idx_strategy.sharding_specs[prenode_local_name].
                    dim_partition_dict)

                for dim, indice_mesh_dims in idx_strategy.sharding_specs[
                        prenode_local_name].dim_partition_dict.items():
                    for indice_mesh_dim in indice_mesh_dims:
                        if indice_mesh_dim in input_mesh_dims:
                            indice_dim_partition_dict.pop(dim)
                            break

                out_partition_dict = {}

                for dim in range(output_dim):
                    if dim < self.axis:
                        if dim in input_dim_partition_dict:
                            out_partition_dict[dim] = \
                                input_dim_partition_dict[dim]
                    elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims:
                        indice_dim = dim - self.axis + self.num_elementwise_dims
                        if indice_dim in indice_dim_partition_dict:
                            out_partition_dict[dim] = \
                                indice_dim_partition_dict[indice_dim]
                    else:
                        input_dim = dim - (indice_size -
                                           self.num_elementwise_dims) + 1
                        if input_dim in input_dim_partition_dict:
                            out_partition_dict[dim] = \
                                input_dim_partition_dict[input_dim]

                dim_partition_dict_mapping = {
                    f"input{input_id}": input_dim_partition_dict,
                    f"input{indice_id}": indice_dim_partition_dict,
                    "output0": out_partition_dict,
                }
                add_sharding_strategy(dim_partition_dict_mapping)

                if self.support_vocab_tp and vocab_tp_dim is not None:
                    vocab_tp_dim_partition_dict = {
                        **input_dim_partition_dict,
                        self.axis: vocab_tp_dim,
                    }
                    dim_partition_dict_mapping = {
                        f"input{input_id}": vocab_tp_dim_partition_dict,
                        f"input{indice_id}": indice_dim_partition_dict,
                        "output0": out_partition_dict,
                    }
                    add_sharding_strategy(dim_partition_dict_mapping,
                                          vocab_tp_dim)

        return strategies_vector
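As a rough standalone illustration of the output-dimension mapping used in the loop above (not part of the library; the function name and example shapes are invented for this sketch), the gather output keeps the data dims before `axis`, then takes the non-elementwise indices dims, then the remaining data dims:

def map_gather_output_partitions(input_rank, indices_rank, axis,
                                 num_elementwise_dims,
                                 input_parts, indices_parts):
    """Return {output_dim: mesh_dims} for a gather, mirroring the loop above."""
    output_rank = input_rank + indices_rank - 1 - num_elementwise_dims
    out = {}
    for dim in range(output_rank):
        if dim < axis:
            # leading output dims come straight from the data tensor
            if dim in input_parts:
                out[dim] = input_parts[dim]
        elif dim < axis + indices_rank - num_elementwise_dims:
            # middle output dims come from the non-elementwise indices dims
            indice_dim = dim - axis + num_elementwise_dims
            if indice_dim in indices_parts:
                out[dim] = indices_parts[indice_dim]
        else:
            # trailing output dims are the data dims after the gathered axis
            input_dim = dim - (indices_rank - num_elementwise_dims) + 1
            if input_dim in input_parts:
                out[dim] = input_parts[input_dim]
    return out

# Embedding-style lookup: data [vocab, hidden] gathered on axis 0 by indices
# [batch, seq]; hidden split on mesh axis 1, batch split on mesh axis 0.
print(map_gather_output_partitions(2, 2, 0, 0, {1: [1]}, {0: [0]}))
# -> {0: [0], 2: [1]}: output [batch, seq, hidden] sharded on batch and hidden

Splitting the gathered axis itself is what the vocab_tp_dim branch above covers: the split is re-added to the data tensor and the output pays an all_reduce over that mesh axis.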
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py
ADDED
@@ -0,0 +1,56 @@
import copy

from .node import Node
from .sharding_strategy import StrategiesVector


class Identity(Node):

    def _update_memory_cost(self, strategies):
        if not self.is_fake:
            super()._update_memory_cost(strategies)
        else:
            # fake nodes for building block/PP connection
            pass

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        # dim_partition_dict can be the same as previous node if solver's time is a problem
        for dim_partition_dict in dim_partition_list:
            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <identity op> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        # if same_spec_id is not -1, the identity node is used as a same-spec-id node
        if self.same_spec_id == -1:
            return super()._profile_sharding_cost(strategy, device_mesh)
        else:
            return 0.0
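For reference, a minimal sketch of the candidate enumeration that Identity relies on (the `_enumerate_all_possible_*_sharding` helpers defined on `Node`, reproduced here as free functions with illustrative names):

def enumerate_1d(mesh_dim, rank):
    # one tensor dim mapped to a (possibly flattened) mesh axis
    return [{i: mesh_dim} for i in range(rank)]

def enumerate_2d(mesh_dim0, mesh_dim1, rank):
    # two distinct tensor dims mapped to the two mesh axes
    return [{i: mesh_dim0, j: mesh_dim1}
            for i in range(rank) for j in range(rank) if i != j]

# Candidates collected above for a rank-2 tensor on a 2-axis logical mesh:
candidates = ([{}] + enumerate_1d([0], 2) + enumerate_1d([1], 2)
              + enumerate_1d([0, 1], 2) + enumerate_2d([0], [1], 2))
print(candidates)
# [{}, {0: [0]}, {1: [0]}, {0: [1]}, {1: [1]}, {0: [0, 1]}, {1: [0, 1]},
#  {0: [0], 1: [1]}, {1: [0], 0: [1]}]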
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py
ADDED
@@ -0,0 +1,79 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class InputNode(Node):

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            if not self.no_memory_footprint:
                strategy.const_memory_footprint = strategy.sharding_specs[
                    'output0'].get_max_sharded_size_per_device()

    def __init__(self, tensor):
        self._layer = None
        self.is_shape_io = False
        self._inputs = []
        self._outputs = []
        self.predecessor_nodes = []
        self.predecessor_nodes_out_index = {}
        self.successor_nodes = []
        self.op_data = {}
        self.global_to_local_op_name = {}
        self.is_replicated = tensor.attrs.get("is_replicated", False)
        self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
        self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
                                                    False)
        self.building_block_id = -1
        self.cost_level = -1
        self.stage_type = None
        self.in_start_block = None
        self.in_end_block = None
        self.in_slowest_block = None
        output = tensor.copy()
        self._outputs.append(output)
        self.op_data['output0'] = output
        self.global_to_local_op_name[output.name] = 'output0'

        self.sharding_weight = 1.0
        self.resharding_weight = 1.0
        self.pipeline_weight = 0
        self.node_name = tensor.name
        self.node_type = 'input_node'
        self.num_inputs = 0
        self.num_outputs = 1
        self.dtype = tensor.dtype
        self.strategies_vector = []
        self.node_runtime_profiler = None

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['output0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = {'output0': dim_partition_dict}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
            sharding_strategy = self._get_sharding_strategy(
                name=f'input-op {sharding_seq}',
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)

        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
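The only memory cost an InputNode records is the constant footprint of its (possibly sharded) output tensor. As a rough sketch of that per-device size, assuming an even split over the mesh axes listed in the partition dict (the mesh shape and helper name are invented for this example):

import math

def sharded_bytes_per_device(shape, dtype_size, partition_dict, mesh_shape):
    # total tensor bytes divided by the number of shards it is split into
    elems = math.prod(shape)
    split = 1
    for mesh_dims in partition_dict.values():
        for axis in mesh_dims:
            split *= mesh_shape[axis]
    return elems * dtype_size // split

# fp16 embedding table [51200, 4096], dim 0 split over both axes of a 2x4 mesh
print(sharded_bytes_per_device([51200, 4096], 2, {0: [0, 1]}, (2, 4)))
# -> 52428800 bytes (~50 MiB) per device instead of ~400 MiB replicated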
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py
ADDED
@@ -0,0 +1,798 @@
import copy
import operator
from functools import reduce

import tensorrt as trt

from ..device_mesh import LogicalDeviceMesh
from ..utils import get_builder_flags
from .comm_spec import CommSpec
from .node import Node
from .sharding_spec import DimSpec
from .sharding_strategy import StrategiesVector


class MatrixMultiply(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
        self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE
        self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE
        self.num_out_dims = len(self.get_output(0).shape)
        dtypes_str = [
            self.get_input(0).dtype_str,
            self.get_input(1).dtype_str,
            self.get_output(0).dtype_str
        ]
        dtypes_size = [
            self.get_input(0).dtype_size,
            self.get_input(1).dtype_size,
            self.get_output(0).dtype_size
        ]
        min_idx = dtypes_size.index(min(dtypes_size))
        self.dtype = dtypes_str[min_idx]
        layer.to_base_class()

    def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh):
        in0_split_dim = -1 if self.op0_transpose else -2
        in1_split_dim = -2 if self.op1_transpose else -1
        name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = '
                f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}')
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim_0
            },
            "input1": {
                in1_split_dim: mesh_dim_1
            },
            "output0": {
                -2: mesh_dim_0,
                -1: mesh_dim_1
            },
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        strategy = self._get_sharding_strategy(name = name, \
            sharding_spec_mapping = sharding_spec_mapping, \
            communication_action_mapping = {})
        return strategy

    def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
                                       device_mesh):
        # handle the case SR = SS x SR
        name = (
            f'{DimSpec(mesh_dim_0)}R = '
            f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R'
            f'_allreduce{DimSpec(mesh_dim_1)}')
        in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
        in1_split_dim = -1 if self.op1_transpose else -2
        # get sharding spec mapping
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim[0]: mesh_dim_0,
                in0_split_dim[1]: mesh_dim_1
            },
            "input1": {
                in1_split_dim: mesh_dim_1
            },
            "output0": {
                -2: mesh_dim_0
            },
        }

        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        # get communication action mapping
        communication_action_mapping = {}
        output0_comm_action = CommSpec(
            comm_pattern='all_reduce',
            sharding_spec=sharding_spec_mapping['output0'],
            logical_process_axis=[mesh_dim_1],
        )
        communication_action_mapping['output0'] = output0_comm_action
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping=communication_action_mapping)

    def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec,
                                dim_partition_dict_mapping, device_mesh):
        output0_comm_action = CommSpec(
            comm_pattern='reduce_scatter',
            sharding_spec=src_spec,
            shard_dim=[rs_dim],
            logical_process_axis=[rs_mesh_dim],
        )
        rs_out_partition_dict_mapping = copy.deepcopy(
            dim_partition_dict_mapping)
        rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim
        rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
            rs_out_partition_dict_mapping, device_mesh)
        if len(rs_out_sharding_spec_mapping) == 0:
            return None

        communication_action_mapping = {}
        communication_action_mapping['output0'] = output0_comm_action
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=rs_out_sharding_spec_mapping,
            communication_action_mapping=communication_action_mapping)

    def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
                                          device_mesh):
        # handle the case SS = SS x SR -> reduce_scatter
        in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
        in1_split_dim = -1 if self.op1_transpose else -2
        # get sharding spec mapping
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim[0]: mesh_dim_0,
                in0_split_dim[1]: mesh_dim_1
            },
            "input1": {
                in1_split_dim: mesh_dim_1
            },
            "output0": {
                -2: mesh_dim_0,
            },
        }
        mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(mm_out_sharding_spec_mapping) == 0:
            return []
        strategies = []
        for rs_dim in range(self.num_out_dims):
            if rs_dim != self.num_out_dims - 2:
                name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
                    'R'
                ] * self.num_out_dims, ['R'] * self.num_out_dims
                name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str(
                    DimSpec(mesh_dim_1))
                name_in1[-2] = str(DimSpec(mesh_dim_1))
                name_out0[-2], name_out0[rs_dim] = str(
                    DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1))
                name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
                    name_in1), ', '.join(name_out0)
                name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
                        f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}')
                ret = self._split_both_contract_rs(
                    name, rs_dim, mesh_dim_1,
                    mm_out_sharding_spec_mapping['output0'],
                    dim_partition_dict_mapping, device_mesh)
                if ret:
                    strategies.append(ret)
        return strategies

    def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
                                       device_mesh):
        name = (
            f'R{DimSpec(mesh_dim_1)} = '
            f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}'
            f'_allreduce{DimSpec(mesh_dim_0)}')
        in0_split_dim = -2 if self.op0_transpose else -1
        in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
        # get sharding specs
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim_0
            },
            "input1": {
                in1_split_dim[0]: mesh_dim_0,
                in1_split_dim[1]: mesh_dim_1
            },
            "output0": {
                -1: mesh_dim_1
            },
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        # get communication actions
        communication_action_mapping = {}
        output0_comm_action = CommSpec(
            comm_pattern='all_reduce',
            sharding_spec=sharding_spec_mapping['output0'],
            logical_process_axis=[mesh_dim_0],
        )
        communication_action_mapping['output0'] = output0_comm_action
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping=communication_action_mapping)

    def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
                                          device_mesh):
        in0_split_dim = -2 if self.op0_transpose else -1
        in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
        # get sharding specs
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim_0
            },
            "input1": {
                in1_split_dim[0]: mesh_dim_0,
                in1_split_dim[1]: mesh_dim_1
            },
            "output0": {
                -1: mesh_dim_1
            },
        }
        mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(mm_out_sharding_spec_mapping) == 0:
            return []
        strategies = []
        for rs_dim in range(self.num_out_dims):
            if rs_dim != self.num_out_dims - 1:
                name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
                    'R'
                ] * self.num_out_dims, ['R'] * self.num_out_dims
                name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str(
                    DimSpec(mesh_dim_1))
                name_in0[-1] = str(DimSpec(mesh_dim_0))
                name_out0[-1], name_out0[rs_dim] = str(
                    DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0))
                name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
                    name_in1), ', '.join(name_out0)
                name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
                        f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}')
                ret = self._split_both_contract_rs(
                    name, rs_dim, mesh_dim_0,
                    mm_out_sharding_spec_mapping['output0'],
                    dim_partition_dict_mapping, device_mesh)
                if ret:
                    strategies.append(ret)
        return strategies

    def _recompute_split_both_contract(self, mesh_dim, device_mesh):
        name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
                f'_allreduce{DimSpec(mesh_dim)}')
        in0_split_dim = -2 if self.op0_transpose else -1
        in1_split_dim = -1 if self.op1_transpose else -2
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim
            },
            "input1": {
                in1_split_dim: mesh_dim
            },
            "output0": {},
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None

        # get communication action
        communication_action_mapping = {}
        output0_comm_action = CommSpec(
            comm_pattern='all_reduce',
            sharding_spec=sharding_spec_mapping['output0'],
            logical_process_axis=[mesh_dim],
        )
        communication_action_mapping['output0'] = output0_comm_action
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping=communication_action_mapping)

    def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh):
        name = (f'{DimSpec(mesh_dim)}R = '
                f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
                f'_reducescatter0_{DimSpec(mesh_dim)}')
        in0_split_dim = -2 if self.op0_transpose else -1
        in1_split_dim = -1 if self.op1_transpose else -2
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim
            },
            "input1": {
                in1_split_dim: mesh_dim
            },
            "output0": {},
        }
        mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(mm_out_sharding_spec_mapping) == 0:
            return []

        strategies = []
        for rs_dim in range(self.num_out_dims):
            name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
                'R'
            ] * self.num_out_dims, ['R'] * self.num_out_dims
            name_in0[-1], name_in1[-2], name_out0[rs_dim] = str(
                DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str(
                    DimSpec(mesh_dim))
            name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
                name_in1), ', '.join(name_out0)
            name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}'
            ret = self._split_both_contract_rs(
                name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'],
                dim_partition_dict_mapping, device_mesh)
            if ret:
                strategies.append(ret)
        return strategies

    def _split_rhs_space_only(self, mesh_dim, device_mesh):
        name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}'
        in1_split_dim = -2 if self.op1_transpose else -1
        # get sharding spec
        dim_partition_dict_mapping = {
            "input0": {},
            "input1": {
                in1_split_dim: mesh_dim
            },
            "output0": {
                -1: mesh_dim
            },
        }
        # We don't have to do anything special for bias here, because
        # the bias is already the same sharding spec as the output0.
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_lhs_space_only(self, mesh_dim, device_mesh):
        name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR'
        in0_split_dim = -1 if self.op0_transpose else -2
        # get sharding spec
        dim_partition_dict_mapping = {
            "input0": {
                in0_split_dim: mesh_dim
            },
            "input1": {},
            "output0": {
                -2: mesh_dim
            },
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _non_split(self, device_mesh):
        name = 'RR = RR x RR'
        # get sharding spec
        dim_partition_dict_mapping = {
            "input0": {},
            "input1": {},
            "output0": {},
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh):
        name = (
            f'{DimSpec(mesh_dim)}b{batch_dim}RR = '
            f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']

        batch_partition_dict = {batch_dim: mesh_dim}
        in0_parition_dict = self._recover_bcast_partition_dict(
            batch_partition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            batch_partition_dict, in1_data)
        out_partition_dict = {batch_dim: mesh_dim}
        # TODO:[KDuan] Double check if MatrixMultiplication's output has bcast in dim
        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0,
                              mesh_dim1, device_mesh):
        name = (
            f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = '
            f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']

        in0_parition_dict = {}
        if batch_dim0 not in in0_data.attrs["broadcast_dims"]:
            in0_parition_dict[batch_dim0] = mesh_dim0
        if batch_dim1 not in in0_data.attrs["broadcast_dims"]:
            in0_parition_dict[batch_dim1] = mesh_dim1

        in1_parition_dict = {}
        if batch_dim0 not in in1_data.attrs["broadcast_dims"]:
            in1_parition_dict[batch_dim0] = mesh_dim0
        if batch_dim1 not in in1_data.attrs["broadcast_dims"]:
            in1_parition_dict[batch_dim1] = mesh_dim1

        batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
        in0_parition_dict = self._recover_bcast_partition_dict(
            batch_partition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            batch_partition_dict, in1_data)
        out_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
                                   device_mesh):

        name = (
            f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = '
            f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']
        in0_parition_dict = {batch_dim: mesh_dim0}
        in1_parition_dict = {batch_dim: mesh_dim0}
        in0_lhs_split_dim = -1 if self.op0_transpose else -2
        in0_parition_dict[in0_lhs_split_dim] = mesh_dim1

        in0_parition_dict = self._recover_bcast_partition_dict(
            in0_parition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            in1_parition_dict, in1_data)
        out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1}

        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
                                   device_mesh):

        name = (
            f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = '
            f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']
        in0_parition_dict = {batch_dim: mesh_dim0}
        in1_parition_dict = {batch_dim: mesh_dim0}

        in1_rhs_split_dim = -2 if self.op1_transpose else -1
        in1_parition_dict[in1_rhs_split_dim] = mesh_dim1

        in0_parition_dict = self._recover_bcast_partition_dict(
            in0_parition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            in1_parition_dict, in1_data)
        out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1}
        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})

    def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1,
                                       device_mesh):

        name = (
            f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
            f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
            f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']
        in0_parition_dict = {batch_dim: mesh_dim0}
        in1_parition_dict = {batch_dim: mesh_dim0}

        in0_contract_dim = -2 if self.op0_transpose else -1
        in1_contract_dim = -1 if self.op1_transpose else -2
        in0_parition_dict[in0_contract_dim] = mesh_dim1
        in1_parition_dict[in1_contract_dim] = mesh_dim1

        in0_parition_dict = self._recover_bcast_partition_dict(
            in0_parition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            in1_parition_dict, in1_data)
        out_partition_dict = {batch_dim: mesh_dim0}
        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(sharding_spec_mapping) == 0:
            return None

        # get communication actions
        communication_action_mapping = {}
        output0_comm_action = CommSpec(
            comm_pattern='all_reduce',
            sharding_spec=sharding_spec_mapping['output0'],
            logical_process_axis=[mesh_dim1],
        )
        communication_action_mapping['output0'] = output0_comm_action
        return self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping=communication_action_mapping)

    def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1,
                                          device_mesh):

        name = (
            f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
            f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
            f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
        )
        in0_data = self.op_data['input0']
        in1_data = self.op_data['input1']
        in0_parition_dict = {batch_dim: mesh_dim0}
        in1_parition_dict = {batch_dim: mesh_dim0}

        in0_contract_dim = -2 if self.op0_transpose else -1
        in1_contract_dim = -1 if self.op1_transpose else -2
        in0_parition_dict[in0_contract_dim] = mesh_dim1
        in1_parition_dict[in1_contract_dim] = mesh_dim1

        in0_parition_dict = self._recover_bcast_partition_dict(
            in0_parition_dict, in0_data)
        in1_parition_dict = self._recover_bcast_partition_dict(
            in1_parition_dict, in1_data)
        out_partition_dict = {batch_dim: mesh_dim0}
        dim_partition_dict_mapping = {
            "input0": in0_parition_dict,
            "input1": in1_parition_dict,
            "output0": out_partition_dict,
        }
        mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if len(mm_out_sharding_spec_mapping) == 0:
            return []

        strategies = []
        for rs_dim in range(self.num_out_dims):
            if rs_dim != batch_dim:
                name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
                    'R'
                ] * self.num_out_dims, ['R'] * self.num_out_dims
                name_in0[batch_dim], name_in0[-1] = str(
                    DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
                name_in1[batch_dim], name_in1[-2] = str(
                    DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
                name_in1[batch_dim], name_out0[rs_dim] = str(
                    DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
                name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
                    name_in1), ', '.join(name_out0)
                name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}'
                ret = self._split_both_contract_rs(
                    name, rs_dim, mesh_dim1,
                    mm_out_sharding_spec_mapping['output0'],
                    dim_partition_dict_mapping, device_mesh)
                if ret:
                    strategies.append(ret)
        return strategies

    def _dp_strategies(self, device_mesh):
        strategies = []
        # S0R = S0R x RR
        strategies.append(self._split_lhs_space_only([0], device_mesh))
        # S1R = S1R x RR
        strategies.append(self._split_lhs_space_only([1], device_mesh))
        # S01R = S01R x RR
        strategies.append(self._split_lhs_space_only([0, 1], device_mesh))
        return strategies

    def _tp_strategies(self, device_mesh: LogicalDeviceMesh):
        strategies = []
        # RR = RS x SR _ AR
        strategies.append(self._recompute_split_both_contract([0], device_mesh))
        strategies.append(self._recompute_split_both_contract([1], device_mesh))
        strategies.append(
            self._recompute_split_both_contract([0, 1], device_mesh))

        if device_mesh.config.enable_reduce_scatter:
            # RS x SR _ reduce scatter
            strategies.extend(
                self._recompute_split_both_contract_rs([0], device_mesh))
            strategies.extend(
                self._recompute_split_both_contract_rs([1], device_mesh))
            strategies.extend(
                self._recompute_split_both_contract_rs([0, 1], device_mesh))

        # RS = RR x RS
        strategies.append(self._split_rhs_space_only([0], device_mesh))
        strategies.append(self._split_rhs_space_only([1], device_mesh))
        strategies.append(self._split_rhs_space_only([0, 1], device_mesh))

        # RS = RS x SS _ AR
        strategies.append(
            self._split_rhs_space_both_contract([0], [1], device_mesh))
        strategies.append(
            self._split_rhs_space_both_contract([1], [0], device_mesh))

        if device_mesh.config.enable_reduce_scatter:
            # RS x SS _ reduce scatter
            strategies.extend(
                self._split_rhs_space_both_contract_rs([0], [1], device_mesh))
            strategies.extend(
                self._split_rhs_space_both_contract_rs([1], [0], device_mesh))

        return strategies

    def _mix_strategies(self, device_mesh):
        strategies = []

        # SR = SS x SR_AR
        strategies.append(
            self._split_lhs_space_both_contract([0], [1], device_mesh))
        strategies.append(
            self._split_lhs_space_both_contract([1], [0], device_mesh))
        if device_mesh.config.enable_reduce_scatter:
            # RS x SS _ reduce scatter
            strategies.extend(
                self._split_lhs_space_both_contract_rs([0], [1], device_mesh))
            strategies.extend(
                self._split_lhs_space_both_contract_rs([1], [0], device_mesh))
        # SS = SR x RS
        strategies.append(self._split_lhs_space_rhs_space([0], [1],
                                                          device_mesh))
        strategies.append(self._split_lhs_space_rhs_space([0], [1],
                                                          device_mesh))

        # RR = RR x RR
        strategies.append(self._non_split(device_mesh))
        return strategies

    def _bmm_strategies(self, device_mesh: LogicalDeviceMesh):
        strategies = []
        bmm_dim = len(self.op_data['output0'].shape)
        if bmm_dim >= 3:
            for batch_dim in range(0, bmm_dim - 2):
                strategies.append(
                    self._split_one_batch_dim(batch_dim, [0], device_mesh))
                strategies.append(
                    self._split_one_batch_dim(batch_dim, [1], device_mesh))
                strategies.append(
                    self._split_one_batch_dim(batch_dim, [0, 1], device_mesh))

                strategies.append(
                    self._split_batch_dim_lhs_space(batch_dim, [0], [1],
                                                    device_mesh))
                strategies.append(
                    self._split_batch_dim_lhs_space(batch_dim, [1], [0],
                                                    device_mesh))

                strategies.append(
                    self._split_batch_dim_rhs_space(batch_dim, [0], [1],
                                                    device_mesh))
                strategies.append(
                    self._split_batch_dim_rhs_space(batch_dim, [1], [0],
                                                    device_mesh))

                strategies.append(
                    self._split_batch_dim_both_contract(batch_dim, [0], [1],
                                                        device_mesh))
                strategies.append(
                    self._split_batch_dim_both_contract(batch_dim, [1], [0],
                                                        device_mesh))
                if device_mesh.config.enable_reduce_scatter:
                    strategies.extend(
                        self._split_batch_dim_both_contract_rs(
                            batch_dim, [0], [1], device_mesh))
                    strategies.extend(
                        self._split_batch_dim_both_contract_rs(
                            batch_dim, [1], [0], device_mesh))
            if bmm_dim >= 4:
                for batch_dim0 in range(0, bmm_dim - 2):
                    for batch_dim1 in range(0, bmm_dim - 2):
                        if batch_dim0 != batch_dim1:
                            strategies.append(
                                self._split_two_batch_dims(
                                    batch_dim0, batch_dim1, [0], [1],
                                    device_mesh))

        return strategies

    def _collect_strategies(self, device_mesh):
        strategies_vector = StrategiesVector(self)
        dp_strategies = self._dp_strategies(device_mesh)
        tp_strategies = self._tp_strategies(device_mesh)
        mix_strategies = self._mix_strategies(device_mesh)
        bmm_strategies = self._bmm_strategies(device_mesh)
        strategies_vector.extend(dp_strategies)
        strategies_vector.extend(tp_strategies)
        strategies_vector.extend(mix_strategies)
        strategies_vector.extend(bmm_strategies)
        return strategies_vector

    def is_fp16(self):
        builder_flags = get_builder_flags()
        return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0

    def _get_math_time(self, strategy, device_mesh):
        shape_in0 = strategy.sharding_specs[
            'input0'].get_sharded_shape_per_device()
        shape_out = strategy.sharding_specs[
            'output0'].get_sharded_shape_per_device()
        m, n = shape_out[-2], shape_out[-1]
        batches = shape_out[:-2]
        k = shape_in0[-2] if self.op0_transpose else shape_in0[-1]
        macs_shape = batches + [m, n, k]
        macs = reduce(operator.mul, macs_shape, 1) * 2
        config = device_mesh.config
        cluster_info = device_mesh.cluster_info
        dtype = self.dtype
        # For fp16 matmul ops that use_fp32_acc=True.
        # They are mistaken for fp32 ops since all of their IO tensors use fp32 dtype.
        if self.is_fp16() and self.dtype == "float32":
            dtype = "float16"
        math_throughput_tflops = getattr(cluster_info.math_throughput, dtype)
        assert math_throughput_tflops != 0, \
            "Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key)
        math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency
        return math_time

    def _update_memory_cost(self, strategies):
        super()._update_memory_cost(strategies)
        # For fp16 matmul ops that use_fp32_acc=True.
        # Their memory footprints are calculated based on fp32 IO tensors.
        # Actually they will use fp16 IO tensors after fused.
        # So we divide all the memory footprints by 2.
        if self.is_fp16() and self.dtype == "float32":
            for strategy in strategies:
                strategy.inout_memory_footprint /= 2
                strategy.peak_memory_footprint /= 2
                strategy.comm_buff_memory_footprint /= 2
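_get_math_time above turns the sharded GEMM shape into a FLOP count and divides by the cluster's peak throughput. A back-of-the-envelope version with placeholder throughput and efficiency numbers (the real values come from device_mesh.cluster_info, not from this sketch):

import operator
from functools import reduce

def matmul_math_time_us(shape_out, k, tflops, efficiency=1.0):
    # 2 * M * N * K multiply-accumulate FLOPs per batch element
    m, n = shape_out[-2], shape_out[-1]
    flops = reduce(operator.mul, list(shape_out[:-2]) + [m, n, k], 1) * 2
    # seconds = flops / (tflops * 1e12); scale to microseconds
    return flops / (tflops * 1e12) * 1e6 * efficiency

# [8, 1024, 4096] x [4096, 4096] sharded GEMM on a ~100 TFLOPS device
print(f"{matmul_math_time_us([8, 1024, 4096], 4096, tflops=100):.1f} us")
# -> 2748.8 us of pure math time at 100% efficiency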
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/node.py
ADDED
@@ -0,0 +1,376 @@
from abc import ABC

from ..config import CostModel
from ..device_mesh import LogicalDeviceMesh
from .comm_spec import CommSpec
from .sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector


class Node(ABC):

    def __init__(self, layer):
        self._layer = layer
        self.is_shape_io = self._layer.is_shape_io
        self._inputs = []
        self._outputs = []
        self.predecessor_nodes = []
        self.predecessor_nodes_out_index = {}
        self.successor_nodes = []
        self.op_data = {}
        self.global_to_local_op_name = {}
        self.num_inputs = 0
        self.is_replicated = layer.attrs.get("is_replicated", False)
        self.same_spec_id = layer.attrs.get("same_spec_id", -1)
        self.is_fake = self.same_spec_id != -1
        self.building_block_id = layer.attrs.get("building_block_id", -1)
        self.cost_level = -1
        self.stage_type = layer.attrs.get("stage_type", None)
        self.in_start_block = layer.attrs.get("in_start_block", False)
        self.in_end_block = layer.attrs.get("in_end_block", False)
        self.in_slowest_block = layer.attrs.get("in_slowest_block", False)
        for i, input in enumerate(layer.inputs):
            if input is None:
                self._inputs.append(None)
                self.op_data[f'input{i}'] = None
                continue
            input = input.copy()
            input.attrs["broadcast_dims"] = []
            self._inputs.append(input)
            self.op_data[f'input{i}'] = input
            self.global_to_local_op_name[input.name] = f'input{i}'

        for i, output in enumerate(layer.outputs):
            output = output.copy()
            output.attrs["broadcast_dims"] = []
            self._outputs.append(output)
            self.op_data[f'output{i}'] = output
            self.global_to_local_op_name[output.name] = f'output{i}'

        self.sharding_weight = 1.0
        self.resharding_weight = 1.0
        self.pipeline_weight = 0
        self.node_name = layer.name
        self.node_type = 'normal_node'
        self.num_inputs = layer.num_inputs
        self.num_outputs = layer.num_outputs
        self.dtype = layer.as_trt().precision
        self.strategies_vector = []
        self.node_runtime_profiler = None

    def post_init(self, graph):
        for input in self.inputs:
            if input is None:
                self.predecessor_nodes.append(None)
                continue
            if input.producer is None:
                predecessor_node = graph.get_node(input.name)
                self.predecessor_nodes.append(predecessor_node)
                self.predecessor_nodes_out_index[predecessor_node] = 0
                predecessor_node.successor_nodes.append(self)
            else:
                predecessor_node = graph.get_node(input.producer.name)
                self.predecessor_nodes.append(predecessor_node)
                self.predecessor_nodes_out_index[
                    predecessor_node] = input.output_index
                predecessor_node.successor_nodes.append(self)

    @property
    def layer(self):
        return self._layer

    def get_input(self, index):
        return self._inputs[index]

    @property
    def inputs(self):
        return self._inputs

    def get_output(self, index):
        return self._outputs[index]

    @property
    def outputs(self):
        return self._outputs

    def collect_strategies(self, device_mesh):
        strategies_vector = self._collect_strategies(device_mesh)
        strategies_vector = self._post_process(strategies_vector)
        self._update_sharding_cost(strategies_vector, device_mesh)
        self.strategies_vector = strategies_vector
        return self.strategies_vector

    def _set_strategy(self, strategy, device_mesh):
        strategies_vector = StrategiesVector(self)
        if strategy is None:
            dim_partition_dict_mapping = {}
            for i in range(self.num_inputs):
                dim_partition_dict_mapping[f'input{i}'] = {}
            for i in range(self.num_outputs):
                dim_partition_dict_mapping[f'output{i}'] = {}

            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            assert 0 != len(
                sharding_spec_mapping
            ), f'failed to set default(all Replicate) strategy for node {self.node_name}'
            name = 'RRs'
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)

        else:
            sharding_specs_map = strategy.sharding_specs
            comm_specs_map = strategy.communication_actions
            dim_partition_dict_mapping = {}
            for op_name, sharding_spec in sharding_specs_map.items():
                dim_partition_dict_mapping[
                    op_name] = sharding_spec.dim_partition_dict
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            assert 0 != len(
                sharding_spec_mapping
            ), f'failed to set strategy for node {self.node_name}'
            comm_specs_mapping = {}
            if len(comm_specs_map) > 0:
                for op_name, comm_spec in comm_specs_map.items():
                    comm_specs_mapping[op_name] = CommSpec(
                        comm_pattern=comm_spec.comm_pattern,
                        sharding_spec=sharding_spec_mapping[op_name],
                        logical_process_axis=comm_spec.logical_process_axis,
                    )
            strategies_vector.append(
                self._get_sharding_strategy(
                    name=strategy.name,
                    sharding_spec_mapping=sharding_spec_mapping,
                    communication_action_mapping=comm_specs_mapping))
        return strategies_vector

    def set_strategy(self, strategy, device_mesh):
        strategies_vector = self._set_strategy(strategy, device_mesh)
        strategies_vector = self._post_process(strategies_vector)
        self._update_sharding_cost(strategies_vector, device_mesh)
        self.strategies_vector = strategies_vector
        return self.strategies_vector

    def update_resharding_cost(self):
        self._update_resharding_cost(self.strategies_vector)
        return self.strategies_vector

    def _to_sharding_spec_mapping(self, dim_partition_dict_mapping,
                                  device_mesh):
        results = {}
        for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items(
        ):
            if op_data_name in self.op_data:
                op_data = self.op_data[op_data_name]

                def _to_sharding_spec(op_data, dim_partition_dict):
                    sharding_spec = ShardingSpec(
                        device_mesh,
                        op_data.dtype_str_size, [*op_data.shape],
                        [*op_data.max_shape], [*op_data.raw_shape],
                        dim_partition_dict=dim_partition_dict)
                    if sharding_spec.sanity_check():
                        return sharding_spec
                    else:
                        return None

                sharding_spec = _to_sharding_spec(op_data, dim_partition_dict)
                if sharding_spec:
                    results[op_data_name] = sharding_spec
                else:
                    return {}
        return results

    def _get_sharding_strategy(self, name, sharding_spec_mapping,
                               communication_action_mapping):
        return ShardingStrategy(
            name=name,
            sharding_specs=sharding_spec_mapping,
            communication_actions=communication_action_mapping,
        )

    def _remove_duplicated_strategy(self, strategies_vector):
        name_checklist = []
        remove_list = []
        for strategy in strategies_vector:
            if strategy.name not in name_checklist:
                name_checklist.append(strategy.name)
            else:
                remove_list.append(strategy)
        for strategy in remove_list:
            strategies_vector.remove(strategy)

    def _post_process(self, strategies_vector):
        # TODO:[KDuan] deal with transpose and dimension 1 problem in ClossalAI, which have been processed before
        for i in range(len(strategies_vector) - 1, -1, -1):
            if strategies_vector[i] is None:
                strategies_vector.pop(i)

        self._remove_duplicated_strategy(strategies_vector)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh):
        elapsed_time = self.node_runtime_profiler.runtime_profile(
            self.layer, {}, {}, strategy, device_mesh)
        return elapsed_time

    def _model_sharding_cost_from_s_curve(self, strategy,
                                          device_mesh: LogicalDeviceMesh):
        '''
        [ToDo][KDuan] preprofile the s_curve
        '''
        sharding_cost = 0.0
        return sharding_cost

    # this method might be overwritten by some Ops
    def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh):
        return 0.0

    # this method might be overwritten by some Ops
    def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh):
        memory_time = (strategy.inout_memory_footprint /
                       device_mesh.cluster_info.memory_bw * 1e-3 *
                       device_mesh.cluster_info.memory_efficiency)
        return memory_time

    def _model_sharding_cost_from_alpha_beta(self, strategy,
                                             device_mesh: LogicalDeviceMesh):
        math_time = self._get_math_time(strategy, device_mesh)
        mem_time = self._get_memory_time(strategy, device_mesh)
        return max(math_time, mem_time)

    def _get_communication_cost(self, strategy):
        total_comm_cost = 0.0
        for op_data_name, comm_spec in strategy.communication_actions.items():
            comm_cost = comm_spec.get_comm_cost()
            total_comm_cost = total_comm_cost + comm_cost
        return total_comm_cost

    def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh):
        self._update_memory_cost(strategies)

        if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA:
            for strategy in strategies:
                strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta(
                    strategy, device_mesh)
        elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE:
            for strategy in strategies:
                strategy.sharding_cost = self._model_sharding_cost_from_s_curve(
                    strategy, device_mesh)
        elif device_mesh.config.sharding_cost_model == CostModel.PROFILE:
            for strategy in strategies:
                strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta(
                    strategy, device_mesh)
                if self.is_shape_io:
                    strategy.sharding_cost = strategy.alpha_beta_cost
                else:
                    strategy.sharding_cost = self._profile_sharding_cost(
                        strategy, device_mesh)
        elif device_mesh.config.sharding_cost_model == CostModel.ZERO:
            for strategy in strategies:
                strategy.sharding_cost = 0.0
        else:
            assert False, 'unsupport sharding cost model option: {}'.format(
                device_mesh.config.sharding_cost_model)

        for strategy in strategies:
            strategy.communication_cost = self._get_communication_cost(strategy)

    def _compute_resharding_cost(self, pre_sharding_sepc, cur_sharding_spec,
                                 op_data):
        transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
            pre_sharding_sepc, cur_sharding_spec)
        return (transform_path, comm_action_sequence, resharding_cost)

    def _update_resharding_cost(self, strategies):
        for strategy in strategies:
            resharding_costs = {}
            for pre_node, out_index in self.predecessor_nodes_out_index.items():
                if pre_node is None:
                    continue
                pre_node_out_data_name = pre_node.get_output(out_index).name
                pre_node_out_data_lname = pre_node.global_to_local_op_name[
                    pre_node_out_data_name]
                if pre_node_out_data_name not in self.global_to_local_op_name:
                    print(f"pre_node_out_data_name = {pre_node_out_data_name}")
                    continue
                cur_node_inp_data_lname = self.global_to_local_op_name[
                    pre_node_out_data_name]
                cur_sharding_spec = strategy.sharding_specs[
                    cur_node_inp_data_lname]

                pre_node_out_sharding_specs = []
                for pre_strategy in pre_node.strategies_vector:
                    pre_node_out_sharding_specs.append(
                        pre_strategy.sharding_specs[pre_node_out_data_lname])

                if pre_node not in resharding_costs:
                    resharding_costs[pre_node.node_name] = []
                for prev_sharding_spec in pre_node_out_sharding_specs:
                    resharding_cost = self._compute_resharding_cost(
                        prev_sharding_spec, cur_sharding_spec,
                        self.op_data[cur_node_inp_data_lname])
                    resharding_costs[pre_node.node_name].append(resharding_cost)
            strategy.resharding_costs = resharding_costs

    def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size):
        dim_partition_list = []
        for i in range(dim_size):
            dim_partition_list.append({i: mesh_dim})
        return dim_partition_list

    def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1,
                                            dim_size):
        dim_partition_list = []
        for i in range(dim_size):
            for j in range(dim_size):
                if i != j:
                    dim_partition_list.append({i: mesh_dim0, j: mesh_dim1})
        return dim_partition_list

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0
            for spec in strategy.sharding_specs.values():
                inout_memory_footprint += spec.get_sharded_size_per_device()
                max_inout_memory_footprint += spec.get_max_sharded_size_per_device(
                )

            # the communication happens
            comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0
            for comm_spec in strategy.communication_actions.values():
                comm_buffer_footprint += comm_spec.get_mem_cost()
                max_comm_buffer_footprint += comm_spec.get_max_mem_cost(
|
348 |
+
|
349 |
+
# when doing the output0 comm action, the input buffer should be released, the buffer is used to estimate the memory time
|
350 |
+
# rather than memory usage
|
351 |
+
strategy.inout_memory_footprint = inout_memory_footprint
|
352 |
+
|
353 |
+
strategy.comm_buff_memory_footprint = comm_buffer_footprint
|
354 |
+
strategy.peak_memory_footprint = max(max_inout_memory_footprint,
|
355 |
+
max_comm_buffer_footprint)
|
356 |
+
|
357 |
+
# The const memory (weight) is recorded in constant layers and should be accumulated
|
358 |
+
strategy.const_memory_footprint = 0.0
|
359 |
+
|
360 |
+
def _generate_bcast_dims(self, batch_dims, out_data_shape):
|
361 |
+
for output in self.outputs:
|
362 |
+
if output.broadcast_across_batch:
|
363 |
+
for bs in batch_dims:
|
364 |
+
if output.shape[
|
365 |
+
bs] == 1 and output.shape[bs] != out_data_shape[bs]:
|
366 |
+
output.attrs["broadcast_dims"].append(bs)
|
367 |
+
|
368 |
+
def _recover_bcast_partition_dict(self, partition_dict, op_data):
|
369 |
+
ret = {}
|
370 |
+
for data_dim, mesh_dim in partition_dict.items():
|
371 |
+
if data_dim not in op_data.attrs[
|
372 |
+
"broadcast_dims"] and data_dim + len(
|
373 |
+
op_data.shape) not in op_data.attrs[
|
374 |
+
"broadcast_dims"] and op_data.shape[data_dim] != 1:
|
375 |
+
ret[data_dim] = mesh_dim
|
376 |
+
return ret
|
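The two enumeration helpers at the end of this file produce the candidate dim_partition_dict entries that every *_node.py file below iterates over. A minimal standalone sketch of what they yield, assuming a 2-D tensor on a 2-D logical mesh (the function names here are local stand-ins, not part of the library API):

# Sketch mirroring _enumerate_all_possible_1d_sharding /
# _enumerate_all_possible_2d_sharding; shapes and mesh axes are illustrative.
def enumerate_1d(mesh_dim, dim_size):
    # shard a single tensor dimension on the given mesh axis (or axes)
    return [{i: mesh_dim} for i in range(dim_size)]

def enumerate_2d(mesh_dim0, mesh_dim1, dim_size):
    # shard two distinct tensor dimensions on two different mesh axes
    return [{i: mesh_dim0, j: mesh_dim1}
            for i in range(dim_size) for j in range(dim_size) if i != j]

if __name__ == "__main__":
    print(enumerate_1d([0], 2))       # [{0: [0]}, {1: [0]}]
    print(enumerate_2d([0], [1], 2))  # [{0: [0], 1: [1]}, {0: [1], 1: [0]}]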
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py
ADDED
@@ -0,0 +1,60 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class Normalization(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.axes = layer.as_trt().axes
        self.weight_bias_dim_base = 0
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            shard_reduction_axes = False
            for dim in range(len(self.get_input(0).shape)):
                if (self.axes & (1 << dim)) and dim in dim_partition_dict:
                    shard_reduction_axes = True
                    break
            if shard_reduction_axes:
                continue
            dim_partition_dict_mapping = {
                "input0": dim_partition_dict,
                "output0": dim_partition_dict,
            }
            if self.num_inputs >= 2:
                dim_partition_dict_mapping['input1'] = {}
            if self.num_inputs >= 3:
                dim_partition_dict_mapping['input2'] = {}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = {} <normalization op> scale {}, bias {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence,
                sharding_spec_mapping['input1'].sharding_sequence
                if self.num_inputs >= 2 else 'None',
                sharding_spec_mapping['input2'].sharding_sequence
                if self.num_inputs >= 3 else 'None',
            )
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
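The bitmask test in _collect_strategies above is what keeps the normalized (reduction) axes replicated. A minimal sketch of that check, assuming a 3-D [batch, seq, hidden] input normalized over the hidden dim:

# axes is a bitmask of reduction dims, as set from layer.as_trt().axes above.
axes = 1 << 2  # normalization reduces over dim 2 in this assumed layout

def shards_reduction_axis(dim_partition_dict, rank=3):
    # True if any candidate sharded dim is also a normalized dim
    return any((axes & (1 << dim)) and dim in dim_partition_dict
               for dim in range(rank))

assert shards_reduction_axis({2: [0]}) is True   # sharding the hidden dim -> strategy skipped
assert shards_reduction_axis({0: [0]}) is False  # batch-dim sharding is kept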
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py
ADDED
@@ -0,0 +1,79 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class OuputNode(Node):

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            if not self.no_memory_footprint:
                strategy.const_memory_footprint = strategy.sharding_specs[
                    'input0'].get_max_sharded_size_per_device()

    def __init__(self, tensor):
        self._layer = None
        self.is_shape_io = False
        self._inputs = []
        self._outputs = []
        self.predecessor_nodes = []
        self.predecessor_nodes_out_index = {}
        self.successor_nodes = []
        self.op_data = {}
        self.global_to_local_op_name = {}
        self.is_replicated = tensor.attrs.get("is_replicated", False)
        self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
        self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
                                                    False)
        self.building_block_id = -1
        self.cost_level = -1
        self.stage_type = None
        self.in_start_block = None
        self.in_end_block = None
        self.in_slowest_block = None
        input = tensor.copy()
        self._inputs.append(input)
        self.op_data['input0'] = input
        self.global_to_local_op_name[input.name] = 'input0'

        self.sharding_weight = 1.0
        self.resharding_weight = 1.0
        self.pipeline_weight = 0
        self.node_name = tensor.name
        self.node_type = 'output_node'
        self.num_inputs = 0
        self.num_outputs = 1
        self.dtype = tensor.dtype
        self.strategies_vector = []
        self.node_runtime_profiler = None

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = {'input0': dim_partition_dict}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            sharding_seq = sharding_spec_mapping['input0'].sharding_sequence
            sharding_strategy = self._get_sharding_strategy(
                name=f'output-op {sharding_seq}',
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)

        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py
ADDED
@@ -0,0 +1,67 @@
import copy
from enum import Enum

from .comm_spec import CommSpec
from .identity_node import Identity
from .sharding_strategy import StrategiesVector


class P2PType(Enum):
    CROSS_DEVICE = 0
    CROSS_HOST = 1


class P2PNode(Identity):

    def __init__(self, layer):
        super().__init__(layer)
        self.p2p_type = layer.attrs["p2p_type"]
        self.is_fake = True

    def _collect_strategies(self, device_mesh):
        # one input for softmax node
        predecessor = self.predecessor_nodes[0]
        strategies_vector = StrategiesVector(self)
        for idx, strategy in enumerate(predecessor.strategies_vector):
            # current node's local name input0 -> global name xxx
            global_input_name = self.op_data['input0'].name
            # global name xxx -> pre node local output name
            prenode_local_name = predecessor.global_to_local_op_name[
                global_input_name]
            dim_partition_dict = copy.deepcopy(
                strategy.sharding_specs[prenode_local_name].dim_partition_dict)
            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue

            logical_process_axis = [
                ['p2p_cross_device']
            ] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']]
            # get communication action mapping
            communication_action_mapping = {}
            output0_comm_action = CommSpec(
                comm_pattern='peer_to_peer',
                sharding_spec=sharding_spec_mapping['output0'],
                logical_process_axis=logical_process_axis,
            )
            communication_action_mapping['output0'] = output0_comm_action

            name = '{} = <P2P op> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping=communication_action_mapping)
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py
ADDED
@@ -0,0 +1,40 @@
from tensorrt_llm.network import PluginInfo, get_plugin_info

from .node import Node
from .sharding_strategy import StrategiesVector


class PluginNode(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.plugin = layer.as_trt().plugin
        self.plugin_type: str = self.plugin.plugin_type
        self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(),
                                                       layer.name)
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        raise NotImplementedError(
            f"Auto parallel does not support {self.plugin_type} plugin right now."
        )

    def _default_strategy(self, device_mesh):
        strategies_vector = StrategiesVector(self)
        dim_partition_dict_mapping = {}
        for idx in range(self.num_inputs):
            dim_partition_dict_mapping[f'input{idx}'] = {}
        for idx in range(self.num_outputs):
            dim_partition_dict_mapping[f'output{idx}'] = {}
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if 0 == len(sharding_spec_mapping):
            return strategies_vector
        name = '{}_all_replicate'.format(self.plugin_type)
        sharding_strategy = self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})
        strategies_vector.append(sharding_strategy)
        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py
ADDED
File without changes
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py
ADDED
@@ -0,0 +1,27 @@
import tensorrt as trt

from tensorrt_llm._utils import trt_dtype_to_str

from ..matmul_node import MatrixMultiply
from ..plugin_node import PluginNode


class GemmPlugin(MatrixMultiply, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
        pfc_as_list = self.plugin_info.pfc_as_list
        self.op0_transpose = (pfc_as_list['transa'][0] == 1)
        self.op1_transpose = (pfc_as_list['transb'][0] == 1)
        self.num_out_dims = len(self.get_output(0).shape)
        self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0]))

    def _collect_strategies(self, device_mesh):
        strategies_vector = MatrixMultiply._collect_strategies(
            self, device_mesh)
        return strategies_vector

    def _get_math_time(self, strategy, device_mesh):
        return MatrixMultiply._get_math_time(self, strategy, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py
ADDED
@@ -0,0 +1,395 @@
from enum import Enum, auto

import numpy as np
import torch

from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.quantization import QuantMode

from ..plugin_node import PluginNode
from ..sharding_strategy import StrategiesVector


# WARNING: Must be in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
class IdxEntry(Enum):
    QKV_TENSOR = auto()
    K_TENSOR = auto()
    V_TENSOR = auto()
    SEQUENCE_LENGTH = auto()
    HOST_PAST_KEY_VALUE_LENGTHS = auto()
    HOST_MAX_ATTENTION_WINDOW = auto()
    HOST_SINK_TOKEN_LENGTH = auto()
    CONTEXT_LENGTHS = auto()
    CACHE_INDIR = auto()
    REQUEST_TYPES = auto()
    KV_CACHE_BLOCK_OFFSETS = auto()
    HOST_KV_CACHE_BLOCK_OFFSETS = auto()
    HOST_KV_CACHE_POOL_POINTERS = auto()
    PAST_KEY_VALUE = auto()
    KV_CACHE_QUANTIZATION_SCALE = auto()
    KV_CACHE_DEQUANTIZATION_SCALE = auto()
    ROTARY_INV_FREQ = auto()
    ROTARY_COS_SIN = auto()
    ALIBI_SLOPES = auto()
    RELATIVE_ATTENTION_BIAS = auto()
    CROSS_QKV = auto()
    CROSS_QKV_LENGTH = auto()
    ENCODER_INPUT_LENGTH = auto()
    HOST_CONTEXT_LENGTH = auto()
    QKV_BIAS_TENSOR = auto()
    SPEC_DECODING_PACKED_MASK = auto()
    SPEC_DECODING_POSITION_OFFSETS = auto()
    SPEC_DECODING_GENERATION_LENGTHS = auto()
    HOST_RUNTIME_PERF_KNOBS = auto()


class IdxEntryParser:

    def __init__(self, plugin_info):
        self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0]
        self.unfuse_qkv_gemm = bool(
            plugin_info.pfc_as_list['unfuse_qkv_gemm'][0])
        self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0])
        self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0])
        self.do_cross_attention = bool(
            plugin_info.pfc_as_list['do_cross_attention'][0])
        self.remove_input_padding = bool(
            plugin_info.pfc_as_list['remove_input_padding'][0])
        self.qkv_bias_enabled = bool(
            plugin_info.pfc_as_list['qkv_bias_enabled'][0])
        self.kv_cache_quant_mode = QuantMode(
            plugin_info.pfc_as_list['kv_cache_quant_mode'][0])
        self.position_embedding_type = PositionEmbeddingType(
            plugin_info.pfc_as_list['position_embedding_type'][0])
        self.is_spec_decoding_enabled = bool(
            plugin_info.pfc_as_list['is_spec_decoding_enabled'][0])
        self.init_entry_to_index()

    # WARNING: Must be in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
    def is_entry_used(self, entry: IdxEntry) -> bool:
        if entry == IdxEntry.QKV_TENSOR:
            return True
        elif entry == IdxEntry.K_TENSOR:
            return self.unfuse_qkv_gemm
        elif entry == IdxEntry.V_TENSOR:
            return self.unfuse_qkv_gemm
        elif entry == IdxEntry.SEQUENCE_LENGTH:
            return self.use_cache
        elif entry == IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS:
            return self.use_cache
        elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW:
            return True
        elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH:
            return True
        elif entry == IdxEntry.CONTEXT_LENGTHS:
            return True
        elif entry == IdxEntry.CACHE_INDIR:
            return self.use_cache
        elif entry == IdxEntry.REQUEST_TYPES:
            return True
        elif entry == IdxEntry.KV_CACHE_BLOCK_OFFSETS:
            return self.use_cache and self.paged_kv_cache
        elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_OFFSETS:
            return self.use_cache and self.paged_kv_cache
        elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS:
            return self.use_cache and self.paged_kv_cache
        elif entry == IdxEntry.PAST_KEY_VALUE:
            return self.use_cache and not self.paged_kv_cache
        elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE:
            return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
            )
        elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE:
            return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
            )
        elif entry == IdxEntry.ROTARY_INV_FREQ:
            return self.position_embedding_type.is_rope()
        elif entry == IdxEntry.ROTARY_COS_SIN:
            return self.position_embedding_type.is_rope()
        elif entry == IdxEntry.ALIBI_SLOPES:
            return self.position_embedding_type.is_alibi()
        elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS:
            return self.position_embedding_type == PositionEmbeddingType.relative
        elif entry == IdxEntry.CROSS_QKV:
            return self.do_cross_attention
        elif entry == IdxEntry.CROSS_QKV_LENGTH:
            return self.do_cross_attention
        elif entry == IdxEntry.ENCODER_INPUT_LENGTH:
            return self.do_cross_attention
        elif entry == IdxEntry.HOST_CONTEXT_LENGTH:
            return self.remove_input_padding
        elif entry == IdxEntry.QKV_BIAS_TENSOR:
            return self.qkv_bias_enabled
        elif entry == IdxEntry.SPEC_DECODING_PACKED_MASK:
            return self.is_spec_decoding_enabled
        elif entry == IdxEntry.SPEC_DECODING_POSITION_OFFSETS:
            return self.is_spec_decoding_enabled
        elif entry == IdxEntry.SPEC_DECODING_GENERATION_LENGTHS:
            return self.is_spec_decoding_enabled
        elif entry == IdxEntry.HOST_RUNTIME_PERF_KNOBS:
            return True
        else:
            return False

    def init_entry_to_index(self):
        self.entry_to_index = {}
        index = 0
        for entry in IdxEntry:
            if self.is_entry_used(entry):
                self.entry_to_index[entry] = index
                index += 1

    def get_index(self, entry: IdxEntry) -> int:
        if entry not in self.entry_to_index:
            raise Exception(
                f"Entry {entry} does not exist in gpt attention plugin layer {self.layer.name}"
            )
        return self.entry_to_index[entry]


def get_partition(device_dim, device_ids):
    if device_dim == [0]:
        partition = device_ids.shape[0]
    elif device_dim == [1]:
        partition = device_ids.shape[1]
    else:
        assert device_dim == [0, 1] or device_dim == [1, 0]
        partition = device_ids.size
    return partition


class GPTAttentionPlugin(PluginNode):

    def __init__(self, layer):
        super().__init__(layer)
        self.parser = IdxEntryParser(self.plugin_info)
        assert self.num_inputs == len(
            self.parser.entry_to_index
        ), f'the number of plugin inputs ({self.num_inputs}) is invalid'
        assert self.num_outputs == (
            2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else 1
        ), f'the number of plugin outputs ({self.num_outputs}) has been changed'

    def _tp_strategy(self, device_mesh):
        strategies_vector = StrategiesVector(self)
        head_dim = 1 if self.parser.remove_input_padding else 2
        # TODO: allow mesh_dim = [0] or [1]
        # for mesh_dim in ([0], [1], [0, 1]):
        for mesh_dim in ([0, 1], ):
            if self.parser.num_kv_heads != 1:
                # MHA or GQA
                # TODO: allow to duplicate kv when #kv_head < #partition
                q_pdict = {
                    head_dim: mesh_dim
                }  # split in heads/hidden dimension
                k_pdict = {
                    head_dim: mesh_dim
                }  # split in heads/hidden dimension
                v_pdict = {
                    head_dim: mesh_dim
                }  # split in heads/hidden dimension
                pastkv_pdict = {2: mesh_dim}  # split in heads dimension
                present_kv_pdict = {2: mesh_dim}  # split in heads dimension
            else:
                # MQA
                q_pdict = {
                    head_dim: mesh_dim
                }  # split in heads/hidden dimension
                k_pdict = {}  # RR
                v_pdict = {}  # RR
                pastkv_pdict = {}  # RR
                present_kv_pdict = {}  # RR

            out0_pdict = {head_dim: mesh_dim}

            dim_partition_dict_mapping = {
                f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict,
                f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict,
                f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict,
                'output0': out0_pdict,
            }
            if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE):
                dim_partition_dict_mapping[
                    f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict
                dim_partition_dict_mapping['output1'] = present_kv_pdict
            for i in range(self.num_inputs):
                if f'input{i}' not in dim_partition_dict_mapping:
                    dim_partition_dict_mapping[f'input{i}'] = {}
            for i in range(self.num_outputs):
                if f'output{i}' not in dim_partition_dict_mapping:
                    dim_partition_dict_mapping[f'output{i}'] = {}

            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = 'gptAttentionPlugin_tp_strategy'
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _dp_strategy(self, device_mesh):
        strategies_vector = StrategiesVector(self)
        for mesh_dim in ([0], [1], [0, 1]):
            dim_partition_dict_mapping = {}
            for i in range(self.num_inputs):
                dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim}
            for i in range(self.num_outputs):
                dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim}

            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = 'gptAttentionPlugin_dp_strategy'
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _collect_strategies(self, device_mesh):
        if device_mesh.size == 1:
            default_strategies = self._default_strategy(device_mesh)
        else:
            # Avoid using the all-replicate strategy for mesh size > 1
            # since the CPP runtime does not support it for gpt attention plugin
            default_strategies = StrategiesVector(self)
        for idx, strategy in enumerate(default_strategies):
            strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}'
        if self.parser.unfuse_qkv_gemm:
            tp_strategies = self._tp_strategy(device_mesh)
            default_strategies.extend(tp_strategies)
        # if we don't split the batch dim, it should be default strategies
        # elif we split the batch dim, it should be dp_strategies
        # we can use above information to distinguish the two kinds of strategy
        if not self.parser.remove_input_padding:
            dp_strategies = self._dp_strategy(device_mesh)
            default_strategies.extend(dp_strategies)
        return default_strategies

    @staticmethod
    def parameter_generator(sharding_specs, plugin_info):

        def get_shape(entry):
            return sharding_specs[
                f'input{parser.get_index(entry)}'].get_sharded_shape_per_device(
                )

        parser = IdxEntryParser(plugin_info)
        updated_input_values = {}
        batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0]
        if parser.use_cache:
            beams_width = get_shape(IdxEntry.CACHE_INDIR)[1]
            max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2]
        elif not parser.remove_input_padding:
            max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1]
        else:
            max_seq_length = 1
        host_request_types = torch.full(
            (batch_size, ),
            1,
            dtype=torch.int32,
            device='cpu',
        )
        updated_input_values[parser.get_index(
            IdxEntry.REQUEST_TYPES)] = host_request_types
        context_lengths = torch.full(
            (batch_size, ),
            max_seq_length - 1,
            dtype=torch.int32,
            device=torch.cuda.current_device(),
        )
        updated_input_values[parser.get_index(
            IdxEntry.CONTEXT_LENGTHS)] = context_lengths
        host_max_attention_window_sizes = torch.tensor(
            [max_seq_length],
            dtype=torch.int32,
            device='cpu',
        )
        updated_input_values[parser.get_index(
            IdxEntry.HOST_MAX_ATTENTION_WINDOW
        )] = host_max_attention_window_sizes
        host_sink_token_length = torch.tensor(
            [0],
            dtype=torch.int32,
            device='cpu',
        )
        updated_input_values[parser.get_index(
            IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length
        if parser.use_cache:
            sequence_length = torch.full((batch_size, ),
                                         max_seq_length,
                                         dtype=torch.int32,
                                         device=torch.cuda.current_device())
            updated_input_values[parser.get_index(
                IdxEntry.SEQUENCE_LENGTH)] = sequence_length
            host_past_key_value_length = torch.full((batch_size, ),
                                                    max_seq_length - 1,
                                                    dtype=torch.int32,
                                                    device='cpu')
            updated_input_values[parser.get_index(
                IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS
            )] = host_past_key_value_length
            cache_indirections = torch.full(
                (batch_size, beams_width, max_seq_length),
                0,
                dtype=torch.int32,
                device=torch.cuda.current_device())
            updated_input_values[parser.get_index(
                IdxEntry.CACHE_INDIR)] = cache_indirections
        if parser.remove_input_padding:
            host_context_lengths = torch.full(get_shape(
                IdxEntry.HOST_CONTEXT_LENGTH),
                                              max_seq_length - 1,
                                              dtype=torch.int32,
                                              device='cpu')
            updated_input_values[parser.get_index(
                IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths
        return updated_input_values

    def _profile_sharding_cost(self, strategy, device_mesh):
        sharding_spec = strategy.sharding_specs[
            f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"]
        shard_dims = sharding_spec.dim_partition_dict
        device_ids = device_mesh.phy_ids
        if 2 in shard_dims:
            device_dim = shard_dims[2]
            partition = get_partition(device_dim, device_ids)
        else:
            partition = 1
        if self.parser.is_entry_used(IdxEntry.K_TENSOR):
            kv_sharding_spec = strategy.sharding_specs[
                f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"]
            kv_shard_dims = kv_sharding_spec.dim_partition_dict
            if 2 in kv_shard_dims:
                kv_device_dim = kv_shard_dims[2]
                kv_partition = get_partition(kv_device_dim, device_ids)
            else:
                kv_partition = 1
        else:
            kv_partition = 1
        num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy()
        num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
        tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy()
        tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy()
        num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
        num_heads = np.maximum(num_heads // partition, 1)
        tp_size[0] = partition
        tp_rank[0] = 0

        updated_layer_attrs = {
            'tp_size': tp_size,
            'tp_rank': tp_rank,
            'num_heads': num_heads,
            'num_kv_heads': num_kv_heads
        }
        updated_input_values = self.parameter_generator(strategy.sharding_specs,
                                                        self.plugin_info)
        elapsed_time = self.node_runtime_profiler.runtime_profile(
            self.layer, updated_layer_attrs, updated_input_values, strategy,
            device_mesh)
        return elapsed_time
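The GPT attention plugin's input tensors are positional, and their indices shift depending on which optional entries are enabled. A minimal sketch of the dense index assignment performed by IdxEntryParser.init_entry_to_index above, using a made-up four-entry enum as a stand-in:

from enum import Enum, auto

class _Entry(Enum):  # illustrative stand-in for IdxEntry
    QKV = auto()
    K = auto()
    V = auto()
    SEQ_LEN = auto()

# With fused QKV, K and V are unused, so they are skipped and the
# remaining inputs keep consecutive indices.
used = {_Entry.QKV: True, _Entry.K: False, _Entry.V: False, _Entry.SEQ_LEN: True}

entry_to_index, index = {}, 0
for entry in _Entry:
    if used[entry]:
        entry_to_index[entry] = index
        index += 1

assert entry_to_index == {_Entry.QKV: 0, _Entry.SEQ_LEN: 1}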
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py
ADDED
@@ -0,0 +1,11 @@
from ..identity_node import Identity
from ..plugin_node import PluginNode


class IdentityPlugin(Identity, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)

    def _collect_strategies(self, device_mesh):
        return Identity._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py
ADDED
@@ -0,0 +1,19 @@
import tensorrt as trt

from ..gather_node import Gather
from ..plugin_node import PluginNode


class LookupPlugin(Gather, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        self.mode = trt.GatherMode.DEFAULT
        self.axis = 0
        self.num_elementwise_dims = 0
        self.input_id = 1
        self.indice_id = 0
        self.support_vocab_tp = True

    def _collect_strategies(self, device_mesh):
        return Gather._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py
ADDED
@@ -0,0 +1,28 @@
from ..normalization_node import Normalization
from ..plugin_node import PluginNode


class LayernormPlugin(Normalization, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        # this is only true for LLM models, because layer norm only affects the hidden dim
        hidden_dim = len(self.op_data['input0'].shape) - 1
        self.axes = 1 << hidden_dim
        self.weight_bias_dim_base = hidden_dim

    def _collect_strategies(self, device_mesh):
        return Normalization._collect_strategies(self, device_mesh)


class RMSnormPlugin(Normalization, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        # this is only true for LLM models, because RMS norm only affects the hidden dim
        hidden_dim = len(self.op_data['input0'].shape) - 1
        self.axes = 1 << hidden_dim
        self.weight_bias_dim_base = hidden_dim

    def _collect_strategies(self, device_mesh):
        return Normalization._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py
ADDED
@@ -0,0 +1,73 @@
from tensorrt_llm._utils import trt_axes_to_dim

from .node import Node
from .sharding_strategy import StrategiesVector


class Reduce(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes)
        self.sum_mapping_dict = {}
        num_input_dims = len(self.get_input(0).shape)
        if layer.as_trt().keep_dims:
            for i in range(num_input_dims):
                self.sum_mapping_dict[i] = i
        else:
            output_index = 0
            for i in range(num_input_dims):
                if i not in self.reduce_dims:
                    self.sum_mapping_dict[i] = output_index
                    output_index += 1
            assert output_index == len(self.get_output(0).shape)
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            recover_dims = []
            out_partition_dict = {}
            for dim in dim_partition_dict.keys():
                if dim in self.reduce_dims:
                    recover_dims.append(dim)
                elif dim in self.sum_mapping_dict:
                    out_partition_dict[
                        self.sum_mapping_dict[dim]] = dim_partition_dict[dim]
                else:
                    assert 0, f'dim {dim} is not in sum_dims or sum_mapping_dict'

            for dim in recover_dims:
                dim_partition_dict.pop(dim)

            in0_parition_dict = dim_partition_dict
            dim_partition_dict_mapping = {
                "input0": in0_parition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <reduce along dim {}> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                self.reduce_dims,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
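A minimal sketch of the input-to-output dimension map that Reduce.__init__ builds above and that _collect_strategies uses to translate an input sharding into an output sharding, assuming a 4-D input reduced over dims 1 and 3 with keep_dims disabled:

# Illustrative values only; reduce_dims normally comes from trt_axes_to_dim().
reduce_dims = (1, 3)
num_input_dims = 4

sum_mapping_dict, output_index = {}, 0
for i in range(num_input_dims):
    if i not in reduce_dims:
        # surviving input dims are renumbered consecutively in the output
        sum_mapping_dict[i] = output_index
        output_index += 1

# input dims 0 and 2 survive and become output dims 0 and 1
assert sum_mapping_dict == {0: 0, 2: 1}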
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py
ADDED
@@ -0,0 +1,56 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class Select(Node):

    def __init__(self, layer):
        super().__init__(layer)
        batch_dims = [i for i in range(len(self.get_output(0).shape))]
        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['output0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            # the three inputs are condition, true tensor and false tensor
            in0_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input0'])
            in1_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input1'])
            in2_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input2'])
            out_partition_dict = dim_partition_dict
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "input1": in1_partition_dict,
                "input2": in2_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <select op {}> {} {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence,
                sharding_spec_mapping['input1'].sharding_sequence,
                sharding_spec_mapping['input2'].sharding_sequence)

            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_consistency.py
ADDED
@@ -0,0 +1,832 @@
1 |
+
import math
|
2 |
+
import operator
|
3 |
+
from functools import reduce
|
4 |
+
from typing import List, Tuple
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
from .comm_spec import CommSpec
|
9 |
+
from .sharding_spec import ShardingSpec
|
10 |
+
|
11 |
+
|
12 |
+
class ShapeConsistencyManager(object):
|
13 |
+
|
14 |
+
def __init__(self):
|
15 |
+
self.forward_only = True
|
16 |
+
self.cached_spec_pairs_transform_path = {}
|
17 |
+
self.cache_hit = 0
|
18 |
+
self.cache_miss = 0
|
19 |
+
|
20 |
+
def all_gather_simulator(self, target_pair):
|
21 |
+
_, shard_list = target_pair
|
22 |
+
new_shard_list = []
|
23 |
+
return new_shard_list
|
24 |
+
|
25 |
+
def all_to_all_simulator(self, f_target_pair, b_target_pair):
|
26 |
+
'''
|
27 |
+
Simulating all-to-all operation, analyze the communication cost
|
28 |
+
and simulate the influence of the DimSpec.
|
29 |
+
|
30 |
+
We BANNED all representations which shard_list in decreasing order,
|
31 |
+
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
|
32 |
+
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
|
33 |
+
Argument:
|
34 |
+
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
35 |
+
and the second element describes which logical axis will be sharded in that dimension.
|
36 |
+
e.g.:
|
37 |
+
all-to-all(S0, S1) -> [S01, R]
|
38 |
+
all-to-all(S0, R) -> [R, S0]
|
39 |
+
Otherwise, we extend the front shard_list to behind.
|
40 |
+
e.g.:
|
41 |
+
all-to-all(R, S1) -> [S1, R]
|
42 |
+
|
43 |
+
Argument:
|
44 |
+
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
45 |
+
and the second element describes which logical axis will be sharded in that dimension.
|
46 |
+
'''
|
47 |
+
_, f_shard_list = f_target_pair
|
48 |
+
_, b_shard_list = b_target_pair
|
49 |
+
if not len(b_shard_list):
|
50 |
+
b_shard_list.extend(f_shard_list)
|
51 |
+
f_shard_list = []
|
52 |
+
else:
|
53 |
+
f_shard_list.extend(b_shard_list)
|
54 |
+
b_shard_list = []
|
55 |
+
|
56 |
+
return f_shard_list, b_shard_list
|
57 |
+
|
58 |
+
def shard_simulator(self, target_pair, legal_sharding_dims):
|
59 |
+
'''
|
60 |
+
Simulating shard operation, analyze the communication cost(always ZERO)
|
61 |
+
and simulate the influence of the DimSpec.
|
62 |
+
|
63 |
+
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
|
64 |
+
In addition, We BANNED all representations which shard_list in decreasing order,
|
65 |
+
such as S10, so shard(S0) -> S10 is NOT allowed.
|
66 |
+
Therefore, for the R dimension, we could just append any legal sharding dim on it.
|
67 |
+
e.g.:
|
68 |
+
shard(R) -> S0
|
69 |
+
For the S dimension, we need to make sure the shard_list after sharding still keep rising order.
|
70 |
+
e.g:
|
71 |
+
shard(S0) -> S01
|
72 |
+
|
73 |
+
Argument:
|
74 |
+
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
|
75 |
+
and the second element describes which logical axis will be sharded in that dimension.
|
76 |
+
'''
|
77 |
+
_, shard_list = target_pair
|
78 |
+
shard_list_list, logical_process_axis = [], []
|
79 |
+
for dim in legal_sharding_dims:
|
80 |
+
if len(shard_list) != 0 and dim <= shard_list[-1]:
|
81 |
+
continue
|
82 |
+
new_shard_list = shard_list + [dim]
|
83 |
+
shard_list_list.append(new_shard_list)
|
84 |
+
logical_process_axis.append([dim])
|
85 |
+
|
86 |
+
# we support sorted 2D mesh here
|
87 |
+
if len(legal_sharding_dims) == 2 and len(shard_list) == 0:
|
88 |
+
shard_list_list.append(legal_sharding_dims)
|
89 |
+
logical_process_axis.append(legal_sharding_dims)
|
90 |
+
return shard_list_list, logical_process_axis
|
91 |
+
|
92 |
+
def mix_gather_simulator(self, f_target_pair, b_target_pair):
|
93 |
+
'''
|
94 |
+
Assume index of f and b target pairs are 'f' and 'b'
|
95 |
+
S0S1 => Input: (f, [0]), (b, [1]) Output: [f, b], [[0], [1]]
|
96 |
+
S1S0 => Input: (f, [1]), (b, [0]) Output: [f, b], [[1], [0]]
|
97 |
+
S01R => Input: (f, [0, 1]), (b, []) Output: [f], [[0, 1]]
|
98 |
+
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], [[0, 1]]
|
99 |
+
'''
|
100 |
+
if f_target_pair[1] and b_target_pair[1]:
|
101 |
+
return [f_target_pair[0],
|
102 |
+
b_target_pair[0]], [f_target_pair[1], b_target_pair[1]]
|
103 |
+
if f_target_pair[1]:
|
104 |
+
return [f_target_pair[0]], [f_target_pair[1]]
|
105 |
+
if b_target_pair[1]:
|
106 |
+
return [b_target_pair[0]], [b_target_pair[1]]
|
107 |
+
|
108 |
+
def get_all_all_gather_spec(self, source_spec, orig_cost):
|
109 |
+
'''
|
110 |
+
Get all valid sharding specs from source_spec with single all-gather operation, and
|
111 |
+
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
112 |
+
For the all-gather operation, we just care about the S dimension.
|
113 |
+
|
114 |
+
Argument:
|
115 |
+
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
116 |
+
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
117 |
+
|
118 |
+
Return:
|
119 |
+
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
|
120 |
+
|
121 |
+
Example:
|
122 |
+
dim_partition_dict = {0: [0], 1: [1]}
|
123 |
+
# DistSpec:
|
124 |
+
# shard_sequence: S0,S1,R
|
125 |
+
# device_mesh_shape: (4, 4)
|
126 |
+
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
127 |
+
shape_consistency_manager = ShapeConsistencyManager()
|
128 |
+
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
129 |
+
print(rst_dict)
|
130 |
+
|
131 |
+
Output:
|
132 |
+
{DistSpec:
|
133 |
+
shard_sequence: R,S1,R
|
134 |
+
device_mesh_shape: (4, 4): 0, DistSpec:
|
135 |
+
shard_sequence: S0,R,R
|
136 |
+
device_mesh_shape: (4, 4): 0}
|
137 |
+
'''
|
138 |
+
valid_spec_dict = {}
|
139 |
+
comm_pattern = 'all_gather'
|
140 |
+
for target_pair in source_spec.dim_partition_dict.items():
|
141 |
+
shard_list = self.all_gather_simulator(target_pair)
|
142 |
+
index = target_pair[0]
|
143 |
+
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
144 |
+
|
145 |
+
# We won't add empty list into dim_partition_dict
|
146 |
+
# The key will be popped if the related shard_list is empty
|
147 |
+
if shard_list:
|
148 |
+
new_dim_partition_dict[index] = shard_list
|
149 |
+
else:
|
150 |
+
new_dim_partition_dict.pop(index)
|
151 |
+
|
152 |
+
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
153 |
+
gather_dim = index
|
154 |
+
logical_process_axis = target_pair[1]
|
155 |
+
comm_spec = CommSpec(comm_pattern,
|
156 |
+
sharding_spec=source_spec,
|
157 |
+
gather_dim=[gather_dim],
|
158 |
+
logical_process_axis=[logical_process_axis],
|
159 |
+
forward_only=self.forward_only)
|
160 |
+
|
161 |
+
# compute the communication cost with CommSpec
|
162 |
+
|
163 |
+
# generate new sharding spec
|
164 |
+
new_sharding_spec = ShardingSpec(
|
165 |
+
source_spec.device_mesh,
|
166 |
+
source_spec.data_type_size,
|
167 |
+
source_spec.entire_shape,
|
168 |
+
source_spec.max_entire_shape,
|
169 |
+
source_spec.raw_shape,
|
170 |
+
dim_partition_dict=new_dim_partition_dict)
|
171 |
+
|
172 |
+
if not new_sharding_spec.sanity_check():
|
173 |
+
continue
|
174 |
+
cost = comm_spec.get_comm_cost()
|
175 |
+
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
|
176 |
+
return valid_spec_dict
|
177 |
+
|
178 |
+
def get_all_all_to_all_spec(self, source_spec, orig_cost):
|
179 |
+
'''
|
180 |
+
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
181 |
+
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
182 |
+
For the all-to-all operation, we just care about the pairs containing S dimension.
|
183 |
+
|
184 |
+
Argument:
|
185 |
+
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
186 |
+
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
187 |
+
|
188 |
+
Return:
|
189 |
+
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
190 |
+
|
191 |
+
Example:
|
192 |
+
dim_partition_dict = {0: [0], 1: [1]}
|
193 |
+
# DistSpec:
|
194 |
+
# shard_sequence: S0,S1,R
|
195 |
+
# device_mesh_shape: (4, 4)
|
196 |
+
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
197 |
+
shape_consistency_manager = ShapeConsistencyManager()
|
198 |
+
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
199 |
+
print(rst_dict)
|
200 |
+
|
201 |
+
Output:
|
202 |
+
{DistSpec:
|
203 |
+
shard_sequence: S01,R,R
|
204 |
+
device_mesh_shape: (4, 4): 0, DistSpec:
|
205 |
+
shard_sequence: R,S1,S0
|
206 |
+
device_mesh_shape: (4, 4): 0, DistSpec:
|
207 |
+
shard_sequence: S0,R,S1
|
208 |
+
device_mesh_shape: (4, 4): 0}
|
209 |
+
'''
|
210 |
+
valid_spec_dict = {}
|
211 |
+
comm_pattern = 'all_to_all'
|
212 |
+
tensor_dims = len(source_spec.entire_shape)
|
213 |
+
for f_index in range(tensor_dims - 1):
|
214 |
+
for b_index in range(f_index + 1, tensor_dims):
|
215 |
+
# skip (R, R) cases
|
216 |
+
if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
|
217 |
+
continue
|
218 |
+
else:
|
219 |
+
if f_index in source_spec.dim_partition_dict:
|
220 |
+
'''
|
221 |
+
# skip (S01, R) -> (R, S01) is NOT allowed
|
222 |
+
if len(source_spec.dim_partition_dict[f_index]) >= 2:
|
223 |
+
continue
|
224 |
+
'''
|
225 |
+
f_target_pair = (f_index, [
|
226 |
+
*source_spec.dim_partition_dict[f_index]
|
227 |
+
])
|
228 |
+
else:
|
229 |
+
f_target_pair = (f_index, [])
|
230 |
+
if b_index in source_spec.dim_partition_dict:
|
231 |
+
'''
|
232 |
+
# skip (R, S01) -> (S01, R) is NOT allowed
|
233 |
+
if len(source_spec.dim_partition_dict[b_index]) >= 2:
|
234 |
+
continue
|
235 |
+
'''
|
236 |
+
b_target_pair = (b_index, [
|
237 |
+
*source_spec.dim_partition_dict[b_index]
|
238 |
+
])
|
239 |
+
else:
|
240 |
+
b_target_pair = (b_index, [])
|
241 |
+
|
242 |
+
# skip (S1, S0) -> S10
|
243 |
+
if f_target_pair[1] and b_target_pair[
|
244 |
+
1] and f_target_pair[1][0] >= b_target_pair[1][0]:
|
245 |
+
continue
|
246 |
+
f_shard_list, b_shard_list = self.all_to_all_simulator(
|
247 |
+
f_target_pair, b_target_pair)
|
248 |
+
f_index = f_target_pair[0]
|
249 |
+
b_index = b_target_pair[0]
|
250 |
+
|
251 |
+
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
252 |
+
if len(f_shard_list) < len(f_target_pair[1]):
|
253 |
+
gather_dim = f_index
|
254 |
+
shard_dim = b_index
|
255 |
+
logical_process_axis = f_target_pair[1]
|
256 |
+
else:
|
257 |
+
gather_dim = b_index
|
258 |
+
shard_dim = f_index
|
259 |
+
logical_process_axis = b_target_pair[1]
|
260 |
+
comm_spec = CommSpec(
|
261 |
+
comm_pattern,
|
262 |
+
sharding_spec=source_spec,
|
263 |
+
gather_dim=[gather_dim],
|
264 |
+
shard_dim=[shard_dim],
|
265 |
+
logical_process_axis=[logical_process_axis],
|
266 |
+
forward_only=self.forward_only)
|
267 |
+
|
268 |
+
# compute the communication cost with CommSpec
|
269 |
+
|
270 |
+
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
271 |
+
|
272 |
+
# We won't add empty list into dim_partition_dict
|
273 |
+
# The key will be popped if the related shard_list is empty
|
274 |
+
if f_shard_list:
|
275 |
+
new_dim_partition_dict[f_index] = f_shard_list
|
276 |
+
else:
|
277 |
+
new_dim_partition_dict.pop(f_index)
|
278 |
+
if b_shard_list:
|
279 |
+
new_dim_partition_dict[b_index] = b_shard_list
|
280 |
+
else:
|
281 |
+
new_dim_partition_dict.pop(b_index)
|
282 |
+
|
283 |
+
# generate new sharding spec
|
284 |
+
|
285 |
+
new_sharding_spec = ShardingSpec(
|
286 |
+
source_spec.device_mesh,
|
287 |
+
source_spec.data_type_size,
|
288 |
+
source_spec.entire_shape,
|
289 |
+
source_spec.max_entire_shape,
|
290 |
+
source_spec.raw_shape,
|
291 |
+
dim_partition_dict=new_dim_partition_dict)
|
292 |
+
if not new_sharding_spec.sanity_check():
|
293 |
+
continue
|
294 |
+
cost = comm_spec.get_comm_cost()
|
295 |
+
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
296 |
+
cost + orig_cost)
|
297 |
+
|
298 |
+
return valid_spec_dict
|
299 |
+
|
300 |
+
def get_all_shard_spec(self, source_spec, orig_cost):
|
301 |
+
'''
|
302 |
+
Get all valid sharding specs from source_spec with single shard operation, and
|
303 |
+
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
304 |
+
For the sharding operation, we just care about legal sharding dimensions.
|
305 |
+
|
306 |
+
Argument:
|
307 |
+
source_spec(ShardingSpec): the ShardingSpec of the source tensor.
|
308 |
+
orig_cost(float): the original communication cost before this operation.
|
309 |
+
|
310 |
+
Return:
|
311 |
+
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single shard operation.
|
312 |
+
|
313 |
+
Example:
|
314 |
+
dim_partition_dict = {0: [0]}
|
315 |
+
# DistSpec:
|
316 |
+
# shard_sequence: S0,R,R
|
317 |
+
# device_mesh_shape: (4, 4)
|
318 |
+
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
319 |
+
shape_consistency_manager = ShapeConsistencyManager()
|
320 |
+
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
321 |
+
print(rst_dict)
|
322 |
+
|
323 |
+
Output:
|
324 |
+
{DistSpec:
|
325 |
+
shard_sequence: S01,R,R
|
326 |
+
device_mesh_shape: (4, 4): 0, DistSpec:
|
327 |
+
shard_sequence: S0,S1,R
|
328 |
+
device_mesh_shape: (4, 4): 0, DistSpec:
|
329 |
+
shard_sequence: S0,R,S1
|
330 |
+
device_mesh_shape: (4, 4): 0}
|
331 |
+
'''
|
332 |
+
valid_spec_dict = {}
|
333 |
+
comm_pattern = 'split'
|
334 |
+
|
335 |
+
# legal sharding dims are the mesh axes (mesh_id) that are still available to use.
|
336 |
+
legal_sharding_dims = [
|
337 |
+
i for i in range(len(source_spec.device_mesh.mesh_shape))
|
338 |
+
]
|
339 |
+
for dim, shard_list in source_spec.dim_partition_dict.items():
|
340 |
+
for element in shard_list:
|
341 |
+
legal_sharding_dims.remove(element)
|
342 |
+
if len(legal_sharding_dims) == 0:
|
343 |
+
return valid_spec_dict
|
344 |
+
|
345 |
+
tensor_dims = len(source_spec.entire_shape)
|
346 |
+
|
347 |
+
for index in range(tensor_dims):
|
348 |
+
if index not in source_spec.dim_partition_dict:
|
349 |
+
shard_list_list, logical_process_axes = self.shard_simulator(
|
350 |
+
(index, []), legal_sharding_dims)
|
351 |
+
else:
|
352 |
+
shard_list_list, logical_process_axes = self.shard_simulator(
|
353 |
+
(index, source_spec.dim_partition_dict[index]),
|
354 |
+
legal_sharding_dims)
|
355 |
+
if not shard_list_list:
|
356 |
+
continue
|
357 |
+
for shard_list, logical_process_axis in zip(shard_list_list,
|
358 |
+
logical_process_axes):
|
359 |
+
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
360 |
+
new_dim_partition_dict[index] = shard_list
|
361 |
+
|
362 |
+
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
363 |
+
comm_spec = CommSpec(
|
364 |
+
comm_pattern,
|
365 |
+
sharding_spec=source_spec,
|
366 |
+
shard_dim=[index],
|
367 |
+
logical_process_axis=[logical_process_axis],
|
368 |
+
forward_only=self.forward_only)
|
369 |
+
|
370 |
+
# generate new sharding spec
|
371 |
+
new_sharding_spec = ShardingSpec(
|
372 |
+
source_spec.device_mesh,
|
373 |
+
source_spec.data_type_size,
|
374 |
+
source_spec.entire_shape,
|
375 |
+
source_spec.max_entire_shape,
|
376 |
+
source_spec.raw_shape,
|
377 |
+
dim_partition_dict=new_dim_partition_dict)
|
378 |
+
if not new_sharding_spec.sanity_check():
|
379 |
+
continue
|
380 |
+
# compute the communication cost with CommSpec
|
381 |
+
cost = comm_spec.get_comm_cost()
|
382 |
+
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
383 |
+
cost + orig_cost)
|
384 |
+
|
385 |
+
return valid_spec_dict
|
386 |
+
|
387 |
+
def get_all_mixed_shard_spec(self, source_spec, orig_cost):
|
388 |
+
'''
|
389 |
+
Get all valid sharding specs from source_spec with single shard operation, and
|
390 |
+
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
391 |
+
For the sharding operation, we just care about legal sharding dimensions.
|
392 |
+
'''
|
393 |
+
valid_spec_dict = {}
|
394 |
+
comm_pattern = 'split'
|
395 |
+
|
396 |
+
# legal sharding dims are the mesh axes (mesh_id) that are still available to use.
|
397 |
+
legal_sharding_dims = [
|
398 |
+
i for i in range(len(source_spec.device_mesh.mesh_shape))
|
399 |
+
]
|
400 |
+
for dim, shard_list in source_spec.dim_partition_dict.items():
|
401 |
+
for element in shard_list:
|
402 |
+
legal_sharding_dims.remove(element)
|
403 |
+
if len(legal_sharding_dims) != 2:
|
404 |
+
return valid_spec_dict
|
405 |
+
|
406 |
+
tensor_dims = len(source_spec.entire_shape)
|
407 |
+
for f_index in range(tensor_dims):
|
408 |
+
for b_index in range(tensor_dims):
|
409 |
+
if f_index != b_index:
|
410 |
+
shard_dims = [f_index, b_index]
|
411 |
+
logical_process_axes = [[legal_sharding_dims[0]],
|
412 |
+
[legal_sharding_dims[1]]]
|
413 |
+
new_dim_partition_dict = source_spec.dim_partition_dict.copy(
|
414 |
+
)
|
415 |
+
new_dim_partition_dict[f_index] = [legal_sharding_dims[0]]
|
416 |
+
new_dim_partition_dict[b_index] = [legal_sharding_dims[1]]
|
417 |
+
comm_spec = CommSpec(
|
418 |
+
comm_pattern,
|
419 |
+
sharding_spec=source_spec,
|
420 |
+
shard_dim=shard_dims,
|
421 |
+
logical_process_axis=logical_process_axes,
|
422 |
+
forward_only=self.forward_only)
|
423 |
+
|
424 |
+
# generate new sharding spec
|
425 |
+
new_sharding_spec = ShardingSpec(
|
426 |
+
source_spec.device_mesh,
|
427 |
+
source_spec.data_type_size,
|
428 |
+
source_spec.entire_shape,
|
429 |
+
source_spec.max_entire_shape,
|
430 |
+
source_spec.raw_shape,
|
431 |
+
dim_partition_dict=new_dim_partition_dict)
|
432 |
+
if not new_sharding_spec.sanity_check():
|
433 |
+
continue
|
434 |
+
cost = comm_spec.get_comm_cost()
|
435 |
+
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
436 |
+
cost + orig_cost)
|
437 |
+
return valid_spec_dict
|
438 |
+
|
439 |
+
def get_all_mix_gather_spec(self, source_spec, orig_cost):
|
440 |
+
'''
|
441 |
+
S0S1 -> RR
|
442 |
+
S1S0 -> RR
|
443 |
+
S01R -> RR
|
444 |
+
RS01 -> RR
|
445 |
+
'''
|
446 |
+
valid_spec_dict = {}
|
447 |
+
comm_pattern = 'all_gather'
|
448 |
+
tensor_dims = len(source_spec.entire_shape)
|
449 |
+
for f_index in range(tensor_dims - 1):
|
450 |
+
for b_index in range(f_index + 1, tensor_dims):
|
451 |
+
if (f_index not in source_spec.dim_partition_dict) and (
|
452 |
+
b_index not in source_spec.dim_partition_dict):
|
453 |
+
continue
|
454 |
+
else:
|
455 |
+
if f_index in source_spec.dim_partition_dict:
|
456 |
+
# skip (S10, R) -> (R, R)
|
457 |
+
'''
|
458 |
+
if len(
|
459 |
+
f_target_pair[1]
|
460 |
+
) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
|
461 |
+
continue
|
462 |
+
'''
|
463 |
+
f_target_pair = (f_index, [
|
464 |
+
*source_spec.dim_partition_dict[f_index]
|
465 |
+
])
|
466 |
+
else:
|
467 |
+
f_target_pair = (f_index, [])
|
468 |
+
if b_index in source_spec.dim_partition_dict:
|
469 |
+
# skip (R, S10) -> (R, R)
|
470 |
+
'''
|
471 |
+
if len(
|
472 |
+
b_target_pair[1]
|
473 |
+
) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
|
474 |
+
continue
|
475 |
+
'''
|
476 |
+
b_target_pair = (b_index, [
|
477 |
+
*source_spec.dim_partition_dict[b_index]
|
478 |
+
])
|
479 |
+
else:
|
480 |
+
b_target_pair = (b_index, [])
|
481 |
+
if len(f_target_pair[1]) + len(b_target_pair[1]) != 2:
|
482 |
+
continue
|
483 |
+
gather_dim, logical_process_axes = self.mix_gather_simulator(
|
484 |
+
f_target_pair, b_target_pair)
|
485 |
+
comm_spec = CommSpec(comm_pattern,
|
486 |
+
sharding_spec=source_spec,
|
487 |
+
gather_dim=gather_dim,
|
488 |
+
logical_process_axis=logical_process_axes,
|
489 |
+
forward_only=self.forward_only,
|
490 |
+
mix_gather=True)
|
491 |
+
|
492 |
+
new_dim_partition_dict = {}
|
493 |
+
# generate new sharding spec
|
494 |
+
new_sharding_spec = ShardingSpec(
|
495 |
+
source_spec.device_mesh,
|
496 |
+
source_spec.data_type_size,
|
497 |
+
source_spec.entire_shape,
|
498 |
+
source_spec.max_entire_shape,
|
499 |
+
source_spec.raw_shape,
|
500 |
+
dim_partition_dict=new_dim_partition_dict)
|
501 |
+
if not new_sharding_spec.sanity_check():
|
502 |
+
continue
|
503 |
+
cost = comm_spec.get_comm_cost()
|
504 |
+
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
505 |
+
cost + orig_cost)
|
506 |
+
|
507 |
+
return valid_spec_dict
|
508 |
+
|
509 |
+
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
|
510 |
+
'''
|
511 |
+
Get all valid sharding specs from source_spec with one step transform, and
|
512 |
+
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
513 |
+
Note:
|
514 |
+
all-gather will eliminate a sharding dimension, all-to-all will keep the number of sharding dimensions the same as before,
|
515 |
+
and shard will add a sharding dimension. Therefore, the results of the above operations are mutually exclusive,
|
516 |
+
so we can safely put them together.
|
517 |
+
|
518 |
+
Argument:
|
519 |
+
source_spec(ShardingSpec): the ShardingSpec of the source tensor.
|
520 |
+
orig_cost(float): the original communication cost before this operation.
|
521 |
+
|
522 |
+
Return:
|
523 |
+
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single one-step transform.
|
524 |
+
'''
|
525 |
+
valid_spec_dict = {}
|
526 |
+
valid_spec_dict.update(
|
527 |
+
self.get_all_all_gather_spec(source_spec, orig_cost))
|
528 |
+
valid_spec_dict.update(
|
529 |
+
self.get_all_all_to_all_spec(source_spec, orig_cost))
|
530 |
+
valid_spec_dict.update(
|
531 |
+
self.get_all_mix_gather_spec(source_spec, orig_cost))
|
532 |
+
valid_spec_dict.update(
|
533 |
+
self.get_all_mixed_shard_spec(source_spec, orig_cost))
|
534 |
+
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
|
535 |
+
return valid_spec_dict
|
536 |
+
|
537 |
+
def mem_cost(self, comm_action_sequence: List[CommSpec], mem_pattern='opt'):
|
538 |
+
"""memory cost of the communication action sequence
|
539 |
+
|
540 |
+
Args:
|
541 |
+
comm_action_sequence (List[CommSpec]): list of communication actions
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
float: peak memory footprint (bytes) of such comm_action_sequence
|
545 |
+
"""
|
546 |
+
|
547 |
+
def compute_shape(sharding_spec: ShardingSpec):
|
548 |
+
if 'opt' == mem_pattern:
|
549 |
+
return sharding_spec.get_sharded_shape_per_device()
|
550 |
+
elif 'max' == mem_pattern:
|
551 |
+
return sharding_spec.get_max_sharded_shape_per_device()
|
552 |
+
else:
|
553 |
+
return 0.0
|
554 |
+
|
555 |
+
def gather_analysis(comm_spec, peak_mem):
|
556 |
+
"""analyze all_gather memory footprint
|
557 |
+
all_gather will allocate memory for the output tensor, and there will be temporary memory for the
|
558 |
+
all_gather operation, which is twice the size of the output tensor.
|
559 |
+
|
560 |
+
Args:
|
561 |
+
comm_spec (CommSpec): input CommSpec
|
562 |
+
"""
|
563 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
564 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
565 |
+
for axes in comm_spec.logical_process_axis:
|
566 |
+
for axis in axes:
|
567 |
+
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[
|
568 |
+
axis]
|
569 |
+
alloc_mem = (input_numel +
|
570 |
+
output_numel * 2) * comm_spec.sharding_spec.dtype_size
|
571 |
+
peak_mem = max(peak_mem, alloc_mem)
|
572 |
+
return peak_mem
|
573 |
+
|
574 |
+
def reduce_scatter_analysis(comm_spec, peak_mem):
|
575 |
+
|
576 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
577 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
578 |
+
output_numel = input_numel
|
579 |
+
for axes in comm_spec.logical_process_axis:
|
580 |
+
for axis in axes:
|
581 |
+
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
|
582 |
+
axis]
|
583 |
+
alloc_mem = (input_numel +
|
584 |
+
output_numel * 2) * comm_spec.sharding_spec.dtype_size
|
585 |
+
peak_mem = max(peak_mem, alloc_mem)
|
586 |
+
|
587 |
+
return peak_mem
|
588 |
+
|
589 |
+
def split_analysis(comm_spec: CommSpec, peak_mem: int):
|
590 |
+
"""analyze split memory footprint
|
591 |
+
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
|
592 |
+
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
|
593 |
+
generate a new tensor in this case, so no memory will be allocated.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
comm_spec (CommSpec): input CommSpec
|
597 |
+
peak_mem (int): current peak memory in bytes
|
600 |
+
"""
|
601 |
+
shard_dim = comm_spec.shard_dim
|
602 |
+
if shard_dim != 0:
|
603 |
+
# if we don't shard the tensor on the first dimension, the split action will
|
604 |
+
# generate a new tensor
|
605 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
606 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
607 |
+
output_numel = input_numel
|
608 |
+
for axes in comm_spec.logical_process_axis:
|
609 |
+
for axis in axes:
|
610 |
+
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
|
611 |
+
axis]
|
612 |
+
alloc_mem = (input_numel +
|
613 |
+
output_numel) * comm_spec.sharding_spec.dtype_size
|
614 |
+
peak_mem = max(peak_mem, alloc_mem)
|
615 |
+
else:
|
616 |
+
# if we shard the tensor on the first dimension, the split action will not generate
|
617 |
+
# a new tensor, and as it will preserve a reference to the input tensor, we could
|
618 |
+
# override the discard_input option here
|
619 |
+
# NOTE: this special case might fail in some weird cases, e.g. if we have three split
|
620 |
+
# actions in the comm actions sequence, the first split action operates on the second dimension,
|
621 |
+
# the second split action operates on the first dimension, and the third split action operates, again,
|
622 |
+
# on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
|
623 |
+
# memory of the same size as the output of the first split action. However, the third split action will discard
|
624 |
+
# the input tensor, and it actually should discard the tensor generated by the first split action, so in
|
625 |
+
# the current memory estimation framework, we will overestimate the memory usage. But the above case is
|
626 |
+
# kind of weird, and I think we could ignore it for now.
|
627 |
+
pass
|
628 |
+
return peak_mem
|
629 |
+
|
630 |
+
def reduce_analysis(comm_spec: CommSpec, peak_mem: int):
|
631 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
632 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
633 |
+
output_numel = input_numel
|
634 |
+
alloc_mem = (input_numel +
|
635 |
+
output_numel) * comm_spec.sharding_spec.dtype_size
|
636 |
+
peak_mem = max(peak_mem, alloc_mem)
|
637 |
+
return peak_mem
|
638 |
+
|
639 |
+
def all2all_analysis(comm_spec: CommSpec, peak_mem: int):
|
640 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
641 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
642 |
+
output_numel = input_numel
|
643 |
+
comm_spec.shard_dim
|
644 |
+
alloc_mem = (input_numel +
|
645 |
+
output_numel * 3) * comm_spec.sharding_spec.dtype_size
|
646 |
+
peak_mem = max(peak_mem, alloc_mem)
|
647 |
+
return peak_mem
|
648 |
+
|
649 |
+
def peer_to_peer_analysis(comm_spec: CommSpec, peak_mem: int):
|
650 |
+
input_shape = compute_shape(comm_spec.sharding_spec)
|
651 |
+
input_numel = reduce(operator.mul, input_shape, 1)
|
652 |
+
alloc_mem = (input_numel) * comm_spec.sharding_spec.dtype_size
|
653 |
+
peak_mem = max(peak_mem, alloc_mem)
|
654 |
+
return peak_mem
|
655 |
+
|
656 |
+
pattern_to_func_dict = {
|
657 |
+
'all_gather': gather_analysis,
|
658 |
+
'all_to_all': all2all_analysis,
|
659 |
+
'split': split_analysis,
|
660 |
+
'all_reduce': reduce_analysis,
|
661 |
+
'reduce_scatter': reduce_scatter_analysis,
|
662 |
+
'peer_to_peer': peer_to_peer_analysis
|
663 |
+
}
|
664 |
+
|
665 |
+
fwd_actions = []
|
666 |
+
# construct the forward comm actions sequence
|
667 |
+
for comm_spec in comm_action_sequence:
|
668 |
+
fwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
|
669 |
+
fwd_actions.append(fwd_action)
|
670 |
+
|
671 |
+
# analyze memory footprint of forward comm actions sequence
|
672 |
+
fwd_peak_numel = 0
|
673 |
+
for idx, action_spec_pair in enumerate(
|
674 |
+
zip(fwd_actions, comm_action_sequence)):
|
675 |
+
# the first forward comm action will not discard input
|
676 |
+
fwd_action, comm_spec = action_spec_pair
|
677 |
+
fwd_peak_numel = fwd_action(comm_spec, fwd_peak_numel)
|
678 |
+
|
679 |
+
return fwd_peak_numel
|
680 |
+
|
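Editorial aside (not part of the diffed file): `gather_analysis` above sizes the peak allocation as the input buffer plus twice the output buffer, all scaled by the element size. The numbers below are invented purely to show that arithmetic; they are not taken from the diff.

# Hypothetical example of the gather_analysis formula.
input_numel = 1024 * 512      # per-device input elements (assumed)
mesh_axis_size = 4            # size of the mesh axis being gathered over (assumed)
dtype_size = 2                # bytes per element for fp16 (assumed)

output_numel = input_numel * mesh_axis_size
# same formula as gather_analysis: input buffer + 2x output buffer, in bytes
alloc_mem = (input_numel + output_numel * 2) * dtype_size
print(alloc_mem)  # 9437184 bytes, roughly 9 MiB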
681 |
+
def print_shape_consistency_result(self,
|
682 |
+
transform_path,
|
683 |
+
comm_action_sequence,
|
684 |
+
resharding_cost,
|
685 |
+
file=None):
|
686 |
+
for idx, tpath in enumerate(transform_path):
|
687 |
+
print(
|
688 |
+
f'sharding_info = [op_shape:{tpath.entire_shape}, sharding_spec:{tpath.sharding_sequence}, sharded_shape:{tpath.get_sharded_shape_per_device()}]',
|
689 |
+
end=" ",
|
690 |
+
file=file)
|
691 |
+
print('->', end=" ", file=file)
|
692 |
+
try:
|
693 |
+
commspec = comm_action_sequence[idx]
|
694 |
+
comm = [
|
695 |
+
commspec.comm_pattern, commspec.gather_dim,
|
696 |
+
commspec.shard_dim, commspec.logical_process_axis
|
697 |
+
]
|
698 |
+
except:
|
699 |
+
comm = ''
|
700 |
+
print(f'comm_info = {comm}', end=" ", file=file)
|
701 |
+
print('->', end=" ", file=file)
|
702 |
+
print(f'total_cost = {resharding_cost}', file=file)
|
703 |
+
|
704 |
+
def construct_transform_path_from_cache(self, src_spec, target_spec,
|
705 |
+
old_transform_path,
|
706 |
+
old_comm_action_sequence,
|
707 |
+
orig_cost):
|
708 |
+
new_transform_path = [src_spec]
|
709 |
+
new_comm_action_sequence = []
|
710 |
+
new_cost = orig_cost
|
711 |
+
new_src_spec = src_spec
|
712 |
+
for idx, old_comm_spec in enumerate(old_comm_action_sequence):
|
713 |
+
new_comm_spec = CommSpec(
|
714 |
+
old_comm_spec.comm_pattern,
|
715 |
+
sharding_spec=new_src_spec,
|
716 |
+
gather_dim=old_comm_spec.gather_dim,
|
717 |
+
shard_dim=old_comm_spec.shard_dim,
|
718 |
+
logical_process_axis=old_comm_spec.logical_process_axis,
|
719 |
+
forward_only=old_comm_spec.forward_only,
|
720 |
+
mix_gather=old_comm_spec.mix_gather)
|
721 |
+
new_comm_action_sequence.append(new_comm_spec)
|
722 |
+
new_cost += new_comm_spec.get_comm_cost()
|
723 |
+
old_target_spec = old_transform_path[idx + 1]
|
724 |
+
new_target_spec = ShardingSpec(new_src_spec.device_mesh,
|
725 |
+
new_src_spec.data_type_size,
|
726 |
+
new_src_spec.entire_shape,
|
727 |
+
new_src_spec.max_entire_shape,
|
728 |
+
new_src_spec.raw_shape,
|
729 |
+
old_target_spec.dim_partition_dict)
|
730 |
+
new_transform_path.append(new_target_spec)
|
731 |
+
new_src_spec = new_target_spec
|
732 |
+
assert new_transform_path[-1].get_sharded_shape_per_device(
|
733 |
+
) == target_spec.get_sharded_shape_per_device(
|
734 |
+
), 'failed to insert the cache transform path'
|
735 |
+
return new_transform_path, new_comm_action_sequence, new_cost
|
736 |
+
|
737 |
+
def shape_consistency(
|
738 |
+
self, source_spec: ShardingSpec, target_spec: ShardingSpec
|
739 |
+
) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
740 |
+
'''
|
741 |
+
This method will find a path to transform source_spec to target_spec with
|
742 |
+
a greedy algorithm.
|
743 |
+
The basic idea is:
|
744 |
+
Step1:
|
745 |
+
Generate all one-step transform sequences from source_spec.
|
746 |
+
Step2:
|
747 |
+
Pick the 'best' sharding spec following the heuristic function.
|
748 |
+
Step3:
|
749 |
+
Repeat the above steps until the source spec is transformed into the target spec.
|
750 |
+
'''
|
751 |
+
MAX_TRANSFORM_STEPS = 20
|
752 |
+
total_cost = 0.0
|
753 |
+
total_steps = 0
|
754 |
+
transform_path = []
|
755 |
+
comm_action_sequence = []
|
756 |
+
# We do nothing if the sharding spec is all the same.
|
757 |
+
if source_spec.sharding_sequence_difference(target_spec) == 0:
|
758 |
+
return (transform_path, comm_action_sequence, total_cost)
|
759 |
+
|
760 |
+
spec_pairs = (str(source_spec.sharding_sequence),
|
761 |
+
str(target_spec.sharding_sequence))
|
762 |
+
|
763 |
+
if spec_pairs in self.cached_spec_pairs_transform_path:
|
764 |
+
transform_path, comm_action_sequence = self.cached_spec_pairs_transform_path[
|
765 |
+
spec_pairs]
|
766 |
+
new_transform_path, new_comm_action_sequence, new_total_cost = self.construct_transform_path_from_cache(
|
767 |
+
source_spec, target_spec, transform_path, comm_action_sequence,
|
768 |
+
total_cost)
|
769 |
+
self.cache_hit += 1
|
770 |
+
return (new_transform_path, new_comm_action_sequence,
|
771 |
+
new_total_cost)
|
772 |
+
|
773 |
+
else:
|
774 |
+
self.cache_miss += 1
|
775 |
+
|
776 |
+
temp_sharding_spec = source_spec
|
777 |
+
transform_path.append(temp_sharding_spec)
|
778 |
+
# To avoid an infinite loop, the loop will break after MAX_TRANSFORM_STEPS transforms
|
779 |
+
while total_steps <= MAX_TRANSFORM_STEPS:
|
780 |
+
valid_transform_spec_dict = self.get_all_one_step_transform_spec(
|
781 |
+
temp_sharding_spec, total_cost)
|
782 |
+
best_difference_score = math.inf
|
783 |
+
|
784 |
+
for sharding_spec, info_pairs in valid_transform_spec_dict.items():
|
785 |
+
comm_spec, cost = info_pairs
|
786 |
+
spec_difference = sharding_spec.sharding_sequence_difference(
|
787 |
+
target_spec)
|
788 |
+
|
789 |
+
if spec_difference == 0:
|
790 |
+
total_cost = cost
|
791 |
+
transform_path.append(sharding_spec)
|
792 |
+
comm_action_sequence.append(comm_spec)
|
793 |
+
self.cached_spec_pairs_transform_path[spec_pairs] = (
|
794 |
+
transform_path, comm_action_sequence)
|
795 |
+
return (transform_path, comm_action_sequence, total_cost)
|
796 |
+
|
797 |
+
if spec_difference < best_difference_score:
|
798 |
+
temp_sharding_spec = sharding_spec
|
799 |
+
temp_cost = cost
|
800 |
+
temp_comm_spec = comm_spec
|
801 |
+
best_difference_score = spec_difference
|
802 |
+
|
803 |
+
transform_path.append(temp_sharding_spec)
|
804 |
+
comm_action_sequence.append(temp_comm_spec)
|
805 |
+
total_cost = temp_cost
|
806 |
+
total_steps += 1
|
807 |
+
|
808 |
+
raise RuntimeError(
|
809 |
+
f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps."
|
810 |
+
)
|
811 |
+
|
812 |
+
def dum_transform_path_from_cache(self):
|
813 |
+
src_specs, tgt_specs, path_strs = [], [], []
|
814 |
+
for spec_pairs, trans_comm_path in self.cached_spec_pairs_transform_path.items(
|
815 |
+
):
|
816 |
+
src_specs.append(spec_pairs[0])
|
817 |
+
tgt_specs.append(spec_pairs[1])
|
818 |
+
trans_paths, comm_specs = trans_comm_path[0], trans_comm_path[1]
|
819 |
+
path_str = f'{spec_pairs[0]}->'
|
820 |
+
for idx in range(1, len(trans_paths)):
|
821 |
+
comm_spec = comm_specs[idx - 1]
|
822 |
+
comm_str = f'{comm_spec.comm_pattern}: gather_dim{comm_spec.gather_dim}, shard_dim{comm_spec.shard_dim}, mesh_axis{comm_spec.logical_process_axis}->'
|
823 |
+
path_str += comm_str
|
824 |
+
path_str += f'{trans_paths[idx].sharding_sequence}->'
|
825 |
+
path_strs.append(path_str)
|
826 |
+
ret_dict = {
|
827 |
+
'src_spec': src_specs,
|
828 |
+
'dst_specs': tgt_specs,
|
829 |
+
'trans_path': path_strs
|
830 |
+
}
|
831 |
+
ret_df = pd.DataFrame.from_dict(ret_dict)
|
832 |
+
return ret_df
|
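Editorial note on the file above: the greedy search in `shape_consistency` repeatedly expands all one-step transforms of the current spec and moves to the neighbor whose sharding sequence is closest to the target. The sketch below is a self-contained illustration of that loop only; the string specs, the `one_step_neighbors` table and the `difference` helper are hypothetical stand-ins for `get_all_one_step_transform_spec` and `sharding_sequence_difference`, not code from this diff.

import math

def one_step_neighbors(spec):
    # Hypothetical, hard-coded neighborhood with per-step costs; the real code
    # derives this from all-gather / all-to-all / shard enumerations.
    table = {
        "S0,S1,R": {"R,S1,R": 20.0, "S0,R,R": 20.0, "S01,R,R": 30.0},
        "R,S1,R": {"R,R,R": 20.0},
        "S0,R,R": {"R,R,R": 20.0, "S0,R,S1": 5.0},
    }
    return table.get(spec, {})

def difference(spec, target):
    # Stand-in for ShardingSpec.sharding_sequence_difference: count differing dims.
    return sum(a != b for a, b in zip(spec.split(","), target.split(",")))

def greedy_transform(source, target, max_steps=20):
    path, total_cost, current = [source], 0.0, source
    for _ in range(max_steps):
        best, best_diff, best_cost = None, math.inf, 0.0
        for spec, step_cost in one_step_neighbors(current).items():
            d = difference(spec, target)
            if d == 0:
                return path + [spec], total_cost + step_cost
            if d < best_diff:
                best, best_diff, best_cost = spec, d, step_cost
        if best is None:
            break
        current = best
        path.append(best)
        total_cost += best_cost
    raise RuntimeError("no transform path found")

print(greedy_transform("S0,S1,R", "R,R,R"))  # (['S0,S1,R', 'R,S1,R', 'R,R,R'], 40.0)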
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_node.py
ADDED
@@ -0,0 +1,41 @@
import copy

from .node import Node
from .sharding_strategy import StrategiesVector


class Shape(Node):

    def _update_memory_cost(self, strategies):
        pass

    def _collect_strategies(self, device_mesh):
        # one input for the shape node
        predecessor = self.predecessor_nodes[0]
        strategies_vector = StrategiesVector(self)
        for idx, strategy in enumerate(predecessor.strategies_vector):
            # current node's local name input0 -> global name xxx
            global_input_name = self.op_data['input0'].name
            # global name xxx -> pre node local output name
            prenode_local_name = predecessor.global_to_local_op_name[
                global_input_name]
            dim_partition_dict = copy.deepcopy(
                strategy.sharding_specs[prenode_local_name].dim_partition_dict)
            in0_partition_dict = dim_partition_dict
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": {},
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                return strategies_vector
            name = '{} = <shape> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_spec.py
ADDED
@@ -0,0 +1,418 @@
import operator
from functools import reduce

import tensorrt as trt

from tensorrt_llm.logger import logger

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'


def _convert_str_to_shard_list(str_spec):
    '''
    Convert str_spec into shard_list.

    Argument:
        str_spec(str): dim spec in str type.
    '''

    if str_spec == 'R':
        return []
    if str_spec == 'S0':
        return [0]
    if str_spec == 'S1':
        return [1]
    if str_spec == 'S01':
        return [0, 1]
|
30 |
+
|
31 |
+
|
32 |
+
def _build_difference_2d_dict():
|
33 |
+
'''
|
34 |
+
Build a difference mapping for 2D device mesh case. It will be used to
|
35 |
+
compute the difference between DimSpec pairs.
|
36 |
+
'''
|
37 |
+
|
38 |
+
source_spec_list = ['R', 'S0', 'S1', 'S01']
|
39 |
+
target_spec_list = ['R', 'S0', 'S1', 'S01']
|
40 |
+
difference_dict = {}
|
41 |
+
for source_spec in source_spec_list:
|
42 |
+
for target_spec in target_spec_list:
|
43 |
+
spec_pair = (source_spec, target_spec)
|
44 |
+
source_shard_list = _convert_str_to_shard_list(source_spec)
|
45 |
+
target_shard_list = _convert_str_to_shard_list(target_spec)
|
46 |
+
|
47 |
+
# source same as target
|
48 |
+
if source_shard_list == target_shard_list:
|
49 |
+
difference = 0
|
50 |
+
|
51 |
+
# all_gather(source) -> target
|
52 |
+
elif len(source_shard_list) == len(
|
53 |
+
target_shard_list
|
54 |
+
) + 1 and source_shard_list[:-1] == target_shard_list:
|
55 |
+
difference = ALLGATHER_COST
|
56 |
+
|
57 |
+
# shard(source) -> target
|
58 |
+
elif len(source_shard_list) == len(
|
59 |
+
target_shard_list
|
60 |
+
) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
|
61 |
+
-1] not in source_shard_list:
|
62 |
+
difference = SHARD_COST
|
63 |
+
|
64 |
+
# S1 -> S0 or S0 -> S1
|
65 |
+
elif len(source_shard_list) == len(target_shard_list):
|
66 |
+
# source -> R -> target
|
67 |
+
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
|
68 |
+
|
69 |
+
# R -> S01
|
70 |
+
elif len(source_shard_list) == len(target_shard_list) - 2:
|
71 |
+
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
|
72 |
+
|
73 |
+
# S01 -> R
|
74 |
+
elif len(source_shard_list) == len(target_shard_list) + 2:
|
75 |
+
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
|
76 |
+
|
77 |
+
# S1 -> S01
|
78 |
+
elif len(source_shard_list) == len(target_shard_list) - 1:
|
79 |
+
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
|
80 |
+
|
81 |
+
# S01 -> S1
|
82 |
+
elif len(source_shard_list) == len(target_shard_list) + 1:
|
83 |
+
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
|
84 |
+
|
85 |
+
else:
|
86 |
+
difference = NAN
|
87 |
+
difference_dict[spec_pair] = difference
|
88 |
+
|
89 |
+
return difference_dict
|
90 |
+
|
91 |
+
|
92 |
+
_difference_dict = _build_difference_2d_dict()
|
93 |
+
|
94 |
+
|
95 |
+
class DimSpec:
|
96 |
+
'''
|
97 |
+
Sharding spec for a single dimension of the sharded tensor. It describes which axes of the
|
98 |
+
logical device mesh the dimension is sharded over and provides a method to compute the difference between two DimSpecs.
|
99 |
+
|
100 |
+
Argument:
|
101 |
+
shard_list(List[int]): if shard_list is empty, the dim spec will be 'R' type.
|
102 |
+
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
103 |
+
'''
|
104 |
+
|
105 |
+
def __init__(self, shard_list):
|
106 |
+
self.is_replica = len(shard_list) == 0
|
107 |
+
self.shard_list = shard_list
|
108 |
+
|
109 |
+
def __eq__(self, other):
|
110 |
+
return str(self) == str(other)
|
111 |
+
|
112 |
+
def __repr__(self):
|
113 |
+
if self.is_replica:
|
114 |
+
return 'R'
|
115 |
+
target = 'S'
|
116 |
+
for dim in self.shard_list:
|
117 |
+
target += str(dim)
|
118 |
+
return target
|
119 |
+
|
120 |
+
def difference(self, other):
|
121 |
+
'''
|
122 |
+
The difference between two DimSpec.
|
123 |
+
|
124 |
+
Argument:
|
125 |
+
other(DimSpec): the dim spec to compare with.
|
126 |
+
|
127 |
+
Return:
|
128 |
+
difference(int): the difference between two DimSpec.
|
129 |
+
|
130 |
+
Example:
|
131 |
+
dim_spec = DimSpec([0])
|
132 |
+
other_dim_spec = DimSpec([0, 1])
|
133 |
+
print(dim_spec.difference(other_dim_spec))
|
134 |
+
|
135 |
+
Output:
|
136 |
+
5
|
137 |
+
'''
|
138 |
+
difference = _difference_dict[(str(self), str(other))]
|
139 |
+
return difference
|
140 |
+
|
141 |
+
|
142 |
+
def get_sharding_sequence(num_dims, dims, device_dims):
|
143 |
+
sharding_sequence = [DimSpec([])] * num_dims
|
144 |
+
for dim, shard_list in zip(dims, device_dims):
|
145 |
+
sharding_sequence[dim] = DimSpec(shard_list)
|
146 |
+
return sharding_sequence
|
147 |
+
|
148 |
+
|
149 |
+
class ShardingSpec:
|
150 |
+
'''
|
151 |
+
Sharding spec for a tensor. It contains info about the logical device mesh this tensor belongs
|
152 |
+
to, the entire shape of the tensor before sharding, and the sharding sequence, which looks like
|
153 |
+
[R, R, S0, S1].
|
154 |
+
|
155 |
+
Argument:
|
156 |
+
device_mesh: A logical view of a physical mesh.
|
157 |
+
entire_shape: The entire shape of tensor before sharded.
|
158 |
+
dim_partition_dict: The key is the dimension of tensor to be sharded,
|
159 |
+
and the value of the key describe which logical axis will be sharded in that dimension.
|
160 |
+
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
|
161 |
+
'''
|
162 |
+
|
163 |
+
def __init__(self,
|
164 |
+
device_mesh,
|
165 |
+
data_type_size,
|
166 |
+
data_shape,
|
167 |
+
max_data_shape,
|
168 |
+
raw_data_shape,
|
169 |
+
dim_partition_dict=None,
|
170 |
+
sharding_sequence=None):
|
171 |
+
self.device_mesh = device_mesh
|
172 |
+
self.data_type_size = data_type_size
|
173 |
+
self.dtype = data_type_size[0]
|
174 |
+
self.dtype_size = data_type_size[1]
|
175 |
+
self.entire_shape = data_shape
|
176 |
+
self.max_entire_shape = max_data_shape
|
177 |
+
self.raw_shape = raw_data_shape
|
178 |
+
self.dim_partition_dict = dim_partition_dict
|
179 |
+
self.sharding_sequence = sharding_sequence
|
180 |
+
self.enable_shard_unbalanced_shape = device_mesh.config.enable_shard_unbalanced_shape
|
181 |
+
self.enable_shard_dynamic_shape = device_mesh.config.enable_shard_dynamic_shape
|
182 |
+
if self.sharding_sequence is None:
|
183 |
+
self.dim_partition_dict = self._merge_same_dim_mesh_list(
|
184 |
+
len(self.entire_shape), self.dim_partition_dict)
|
185 |
+
self.dim_partition_dict = self._convert_dim_partition_dict(
|
186 |
+
len(self.entire_shape), self.dim_partition_dict)
|
187 |
+
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
|
188 |
+
self.convert_dict_to_shard_sequence()
|
189 |
+
elif self.dim_partition_dict is None:
|
190 |
+
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
|
191 |
+
self.convert_shard_sequence_to_dict()
|
192 |
+
self.dim_partition_dict = self._merge_same_dim_mesh_list(
|
193 |
+
len(self.entire_shape), self.dim_partition_dict)
|
194 |
+
self.dim_partition_dict = self._convert_dim_partition_dict(
|
195 |
+
len(self.entire_shape), self.dim_partition_dict)
|
196 |
+
|
197 |
+
self.sharded_shape, self.max_sharded_shape = [*self.entire_shape], [
|
198 |
+
*self.max_entire_shape
|
199 |
+
]
|
200 |
+
for dim, shard_list in self.dim_partition_dict.items():
|
201 |
+
mesh_list = [
|
202 |
+
self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list
|
203 |
+
]
|
204 |
+
shard_partitions = reduce(operator.mul, mesh_list, 1)
|
205 |
+
self.sharded_shape[dim] = (self.sharded_shape[dim] +
|
206 |
+
shard_partitions - 1) // shard_partitions
|
207 |
+
self.max_sharded_shape[dim] = (self.max_sharded_shape[dim] +
|
208 |
+
shard_partitions -
|
209 |
+
1) // shard_partitions
|
210 |
+
|
211 |
+
def print_spec(self, file=None):
|
212 |
+
print(
|
213 |
+
f"sharding_sequence = {self.sharding_sequence}, shape = {self.get_sharded_shape_per_device()}",
|
214 |
+
file=file,
|
215 |
+
)
|
216 |
+
|
217 |
+
def _merge_same_dim_mesh_list(self, dim_size, dim_partition_dict):
|
218 |
+
'''
|
219 |
+
This method merges different keys that point to the same physical position.
|
220 |
+
|
221 |
+
For example:
|
222 |
+
dim_partition_dict: {1: [0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor; dims 1 and -1 point to the same physical position.
|
223 |
+
In this method, above dim_partition_dict will be converted to {1: [0, 1]}
|
224 |
+
'''
|
225 |
+
converted_dim_partition_dict = {}
|
226 |
+
for dim, mesh_list in dim_partition_dict.items():
|
227 |
+
if dim < 0:
|
228 |
+
dim = dim_size + dim
|
229 |
+
if dim not in converted_dim_partition_dict:
|
230 |
+
converted_dim_partition_dict[dim] = mesh_list
|
231 |
+
else:
|
232 |
+
converted_dim_partition_dict[dim].extend(mesh_list)
|
233 |
+
converted_dim_partition_dict[dim].sort()
|
234 |
+
return converted_dim_partition_dict
|
235 |
+
|
236 |
+
def _convert_dim_partition_dict(self, dim_size, dim_partition_dict):
|
237 |
+
dims_to_convert = []
|
238 |
+
for dim, mesh_list in dim_partition_dict.items():
|
239 |
+
if dim < 0:
|
240 |
+
dims_to_convert.append(dim)
|
241 |
+
for dim in dims_to_convert:
|
242 |
+
dim_partition_dict.pop(dim)
|
243 |
+
dim_partition_dict[dim_size + dim] = mesh_list
|
244 |
+
return dim_partition_dict
|
245 |
+
|
246 |
+
def _remove_mesh_dim_one(self, dim_partition_dict):
|
247 |
+
dims_to_remove = []
|
248 |
+
for dim, mesh_list in dim_partition_dict.items():
|
249 |
+
new_mesh_list = []
|
250 |
+
for mesh_dim in mesh_list:
|
251 |
+
if self.device_mesh.mesh_shape[mesh_dim] != 1:
|
252 |
+
new_mesh_list.append(mesh_dim)
|
253 |
+
if 0 != len(new_mesh_list):
|
254 |
+
dim_partition_dict[dim] = new_mesh_list
|
255 |
+
else:
|
256 |
+
dims_to_remove.append(dim)
|
257 |
+
for dim in dims_to_remove:
|
258 |
+
dim_partition_dict.pop(dim)
|
259 |
+
return dim_partition_dict
|
260 |
+
|
261 |
+
def __repr__(self):
|
262 |
+
res = "DistSpec("
|
263 |
+
res += f"shard_sequence={self.sharding_sequence},"
|
264 |
+
res += f"shape={self.device_mesh.mesh_shape}"
|
265 |
+
res += ")"
|
266 |
+
return res
|
267 |
+
|
268 |
+
def sanity_check(self):
|
269 |
+
# make sure all axes in logical device mesh only be used once
|
270 |
+
dim_check_list = [*range(len(self.device_mesh.mesh_shape))]
|
271 |
+
for dim, shard_list in self.dim_partition_dict.items():
|
272 |
+
for element in shard_list:
|
273 |
+
if element in dim_check_list:
|
274 |
+
dim_check_list.remove(element)
|
275 |
+
else:
|
276 |
+
logger.warning(
|
277 |
+
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}. dim_partition_dict={self.dim_partition_dict}"
|
278 |
+
)
|
279 |
+
return False
|
280 |
+
|
281 |
+
# make sure that the dimension is not out of index
|
282 |
+
for dim in self.dim_partition_dict.keys():
|
283 |
+
# we have tried to convert the negative value to positive value, if it is larger than the dim_size or negative still, it is out of index
|
284 |
+
if dim >= len(self.entire_shape) or dim < 0:
|
285 |
+
print(
|
286 |
+
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
|
287 |
+
)
|
288 |
+
return False
|
289 |
+
|
290 |
+
if not self.enable_shard_dynamic_shape:
|
291 |
+
# make sure to not to shard on dynamic shape
|
292 |
+
for dim, shard_list in self.dim_partition_dict.items():
|
293 |
+
if len(shard_list) == 0:
|
294 |
+
continue
|
295 |
+
if len(self.raw_shape) == 0:
|
296 |
+
continue
|
297 |
+
if -1 == self.raw_shape[dim]:
|
298 |
+
return False
|
299 |
+
|
300 |
+
# make sure that the sharding for a dimension is divisible by the number of devices
|
301 |
+
for dim, shard_list in self.dim_partition_dict.items():
|
302 |
+
if len(shard_list) == 0:
|
303 |
+
continue
|
304 |
+
tensor_dim_size = self.entire_shape[dim]
|
305 |
+
num_devices = 1
|
306 |
+
|
307 |
+
for element in shard_list:
|
308 |
+
num_devices *= self.device_mesh.mesh_shape[element]
|
309 |
+
if num_devices == 1:
|
310 |
+
# we only support RR when the device is 1
|
311 |
+
return False
|
312 |
+
|
313 |
+
if not self.enable_shard_unbalanced_shape:
|
314 |
+
if tensor_dim_size % num_devices != 0 or tensor_dim_size == 1:
|
315 |
+
'''
|
316 |
+
print(
|
317 |
+
f'The size of static dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
|
318 |
+
)
|
319 |
+
'''
|
320 |
+
return False
|
321 |
+
else:
|
322 |
+
if tensor_dim_size == 1:
|
323 |
+
return False
|
324 |
+
'''
|
325 |
+
if self.get_sharded_size_per_device() > (2**31 - 1):
|
326 |
+
print(
|
327 |
+
f'memory footprint per device {self.get_sharded_size_per_device()} is larger than 2**31 - 1'
|
328 |
+
)
|
329 |
+
return False
|
330 |
+
'''
|
331 |
+
return True
|
332 |
+
|
333 |
+
def convert_dict_to_shard_sequence(self):
|
334 |
+
'''
|
335 |
+
Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
|
336 |
+
'''
|
337 |
+
sharding_sequence = [DimSpec([])] * len(self.entire_shape)
|
338 |
+
for dim, shard_list in self.dim_partition_dict.items():
|
339 |
+
sharding_sequence[dim] = DimSpec(shard_list)
|
340 |
+
self.sharding_sequence = sharding_sequence
|
341 |
+
|
342 |
+
def convert_shard_sequence_to_dict(self):
|
343 |
+
'''
|
344 |
+
Convert sharding_sequence into dim_partition_dict.
|
345 |
+
'''
|
346 |
+
new_dim_partition_dict = {}
|
347 |
+
for index, dim_spec in enumerate(self.sharding_sequence):
|
348 |
+
if not dim_spec.is_replica:
|
349 |
+
if index not in new_dim_partition_dict:
|
350 |
+
new_dim_partition_dict[index] = []
|
351 |
+
new_dim_partition_dict[index].extend(dim_spec.shard_list)
|
352 |
+
self.dim_partition_dict = new_dim_partition_dict
|
353 |
+
|
354 |
+
def sharding_sequence_difference(self, other):
|
355 |
+
'''
|
356 |
+
This function is a naive version of difference computation. It simply accumulates the per-dimension differences between the
|
357 |
+
pair of sharding sequences.
|
358 |
+
|
359 |
+
Example:
|
360 |
+
dim_partition_dict = {0: [0, 1]}
|
361 |
+
# DistSpec:
|
362 |
+
# shard_sequence: S01,R,R
|
363 |
+
# device_mesh_shape: (4, 4)
|
364 |
+
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
365 |
+
dim_partition_dict_to_compare = {0: [0], 1: [1]}
|
366 |
+
# DistSpec:
|
367 |
+
# shard_sequence: S0,S1,R
|
368 |
+
# device_mesh_shape: (4, 4)
|
369 |
+
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
|
370 |
+
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
|
371 |
+
|
372 |
+
Output:
|
373 |
+
25
|
374 |
+
|
375 |
+
Argument:
|
376 |
+
other(ShardingSpec): The ShardingSpec to compared with.
|
377 |
+
|
378 |
+
Return:
|
379 |
+
difference(int): Difference between two ShardingSpec.
|
380 |
+
'''
|
381 |
+
assert len(self.sharding_sequence) == len(
|
382 |
+
other.sharding_sequence
|
383 |
+
), f'Cannot compare difference for two sharding specs with different length.'
|
384 |
+
difference = 0
|
385 |
+
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence,
|
386 |
+
other.sharding_sequence):
|
387 |
+
difference += orig_dim_spec.difference(other_dim_spec)
|
388 |
+
return difference
|
389 |
+
|
390 |
+
def get_sharded_shape_per_device(self, ):
|
391 |
+
return self.sharded_shape
|
392 |
+
|
393 |
+
def get_sharded_element_per_device(self, ):
|
394 |
+
sharded_shape = self.get_sharded_shape_per_device()
|
395 |
+
if len(sharded_shape) == 0:
|
396 |
+
num_elements = 1
|
397 |
+
else:
|
398 |
+
num_elements = trt.volume(sharded_shape)
|
399 |
+
return num_elements
|
400 |
+
|
401 |
+
def get_sharded_size_per_device(self, ):
|
402 |
+
num_elements = self.get_sharded_element_per_device()
|
403 |
+
return num_elements * self.dtype_size
|
404 |
+
|
405 |
+
def get_max_sharded_shape_per_device(self, ):
|
406 |
+
return self.max_sharded_shape
|
407 |
+
|
408 |
+
def get_max_sharded_element_per_device(self, ):
|
409 |
+
max_sharded_shape = self.get_max_sharded_shape_per_device()
|
410 |
+
if len(max_sharded_shape) == 0:
|
411 |
+
num_elements = 1
|
412 |
+
else:
|
413 |
+
num_elements = trt.volume(max_sharded_shape)
|
414 |
+
return num_elements
|
415 |
+
|
416 |
+
def get_max_sharded_size_per_device(self, ):
|
417 |
+
num_elements = self.get_max_sharded_element_per_device()
|
418 |
+
return num_elements * self.dtype_size
|
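Editorial note on sharding_spec.py above: the per-device shape is derived from dim_partition_dict by ceil-dividing each sharded dimension by the product of the mesh axes it is sharded over. The snippet below reproduces that rule on plain Python data; the mesh and tensor shapes are made-up values, not taken from the diff.

from functools import reduce
import operator

mesh_shape = (4, 4)                    # hypothetical 2D device mesh
entire_shape = [128, 96, 64]           # hypothetical tensor shape
dim_partition_dict = {0: [0], 1: [1]}  # shard dim 0 over mesh axis 0, dim 1 over mesh axis 1

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    num_devices = reduce(operator.mul, (mesh_shape[axis] for axis in shard_list), 1)
    # same ceil-division ShardingSpec.__init__ uses to size each device's slice
    sharded_shape[dim] = (sharded_shape[dim] + num_devices - 1) // num_devices

print(sharded_shape)  # [32, 24, 64]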
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_strategy.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
class ShardingStrategy(object):
|
2 |
+
|
3 |
+
def __init__(self,
|
4 |
+
name=None,
|
5 |
+
sharding_specs=None,
|
6 |
+
communication_actions=None):
|
7 |
+
self.name = name or ""
|
8 |
+
self.sharding_specs = sharding_specs or {}
|
9 |
+
self.communication_actions = communication_actions
|
10 |
+
self.sharding_cost = 0
|
11 |
+
self.communication_cost = 0
|
12 |
+
self.resharding_costs = {}
|
13 |
+
self.best_resharding_cost = {}
|
14 |
+
self.node_names = {}
|
15 |
+
|
16 |
+
self.comm_buff_memory_footprint = 0
|
17 |
+
self.inout_memory_footprint = 0
|
18 |
+
self.const_memory_footprint = 0
|
19 |
+
self.peak_memory_footprint = 0
|
20 |
+
self.computation_macs = 0
|
21 |
+
self.alpha_beta_cost = 0
|
22 |
+
|
23 |
+
def print_strategy(self, best_resharding_cost_only=False, file=None):
|
24 |
+
|
25 |
+
def print_resharding_costs(resharding_cost):
|
26 |
+
for prenode_node_name, rcosts in resharding_cost.items():
|
27 |
+
if isinstance(prenode_node_name, int):
|
28 |
+
idx = prenode_node_name
|
29 |
+
prenode_node_name = self.node_names[idx]
|
30 |
+
print(f' pre_node = {idx} {prenode_node_name}',
|
31 |
+
file=file)
|
32 |
+
else:
|
33 |
+
print(f' pre_node = {prenode_node_name}', file=file)
|
34 |
+
for idx, rcost in enumerate(rcosts):
|
35 |
+
transpaths, commspecs, cost = rcost
|
36 |
+
print(f' {idx}: ', end=' ', file=file)
|
37 |
+
device_mesh.shape_consistency_manager.print_shape_consistency_result(
|
38 |
+
transpaths, commspecs, cost, file)
|
39 |
+
|
40 |
+
print(f'name = {self.name}', file=file)
|
41 |
+
print(f'sharding_cost = {self.sharding_cost}', file=file)
|
42 |
+
print(
|
43 |
+
f'communication_buffer_memory_footprint = {self.comm_buff_memory_footprint}, communication_cost = {self.communication_cost}',
|
44 |
+
file=file)
|
45 |
+
print(f'inout_memory_footprint = {self.inout_memory_footprint}',
|
46 |
+
file=file)
|
47 |
+
print(f'peak_memory_footprint = {self.peak_memory_footprint}',
|
48 |
+
file=file)
|
49 |
+
print(f'const_memory_footprint = {self.const_memory_footprint}',
|
50 |
+
file=file)
|
51 |
+
print('sharding_specs:', file=file)
|
52 |
+
device_mesh = None
|
53 |
+
for specname, spec in self.sharding_specs.items():
|
54 |
+
print(specname + ', ', end=' ', file=file)
|
55 |
+
spec.print_spec(file)
|
56 |
+
device_mesh = spec.device_mesh
|
57 |
+
|
58 |
+
if best_resharding_cost_only and self.best_resharding_cost:
|
59 |
+
print('best_resharding_costs:', file=file)
|
60 |
+
print_resharding_costs(self.best_resharding_cost)
|
61 |
+
else:
|
62 |
+
print('resharding costs:', file=file)
|
63 |
+
print_resharding_costs(self.resharding_costs)
|
64 |
+
|
65 |
+
|
66 |
+
class StrategiesVector(list):
|
67 |
+
'''
|
68 |
+
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
69 |
+
strategies of the node.
|
70 |
+
|
71 |
+
Argument:
|
72 |
+
node (Node): node for which the list of sharding strategies are generated.
|
73 |
+
'''
|
74 |
+
|
75 |
+
def __init__(self, node):
|
76 |
+
super().__init__()
|
77 |
+
self.node = node
|
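Editorial note before shuffle_node.py: the per-dimension transition costs defined in sharding_spec.py drive both DimSpec.difference and the greedy path search. The stand-alone snippet below re-derives a few table entries from the same constants (ALLGATHER_COST = 20, SHARD_COST = 5, STEP_PENALTY = 6); the chosen pairs are arbitrary examples, not code from the diff.

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6

# Transitions between per-dimension specs on a 2D mesh and the cost
# _build_difference_2d_dict assigns to them.
examples = {
    ("S0", "S0"): 0,                                           # identical specs
    ("S01", "S0"): ALLGATHER_COST,                             # all-gather one mesh axis
    ("S0", "S01"): SHARD_COST,                                 # shard one more mesh axis
    ("S0", "S1"): ALLGATHER_COST + STEP_PENALTY + SHARD_COST,  # S0 -> R -> S1
    ("R", "S01"): SHARD_COST + STEP_PENALTY + SHARD_COST,      # R -> S0 -> S01
    ("S01", "R"): ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST,  # S01 -> S0 -> R
}

for (src, dst), cost in examples.items():
    print(f"{src} -> {dst}: {cost}")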
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shuffle_node.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
import copy
|
2 |
+
|
3 |
+
from .node import Node
|
4 |
+
from .sharding_strategy import StrategiesVector
|
5 |
+
|
6 |
+
|
7 |
+
class Shuffle(Node):
|
8 |
+
|
9 |
+
def __init__(self, layer):
|
10 |
+
super().__init__(layer)
|
11 |
+
layer.to_subclass()
|
12 |
+
self.first_tanspose_dims = layer.as_trt().first_transpose
|
13 |
+
self.second_transpose_dims = layer.as_trt().second_transpose
|
14 |
+
self.zero_is_placeholder = layer.as_trt().zero_is_placeholder
|
15 |
+
self.is_first_transepose_identity = (sorted(
|
16 |
+
self.first_tanspose_dims) == list(self.first_tanspose_dims))
|
17 |
+
self.input_shape = self.get_input(0).shape
|
18 |
+
self.is_second_transepose_identity = (sorted(
|
19 |
+
self.second_transpose_dims) == list(self.second_transpose_dims))
|
20 |
+
|
21 |
+
output_shape = list(self.get_output(0).shape)
|
22 |
+
self.reshape_dims = copy.deepcopy(output_shape)
|
23 |
+
if not self.is_second_transepose_identity:
|
24 |
+
for i in self.second_transpose_dims:
|
25 |
+
if self.second_transpose_dims[i] != i:
|
26 |
+
self.reshape_dims[
|
27 |
+
self.second_transpose_dims[i]] = output_shape[i]
|
28 |
+
self.is_reshape_identity = (list(self.reshape_dims) == list(
|
29 |
+
self.input_shape))
|
30 |
+
layer.to_base_class()
|
31 |
+
|
32 |
+
def _collect_transpose_strategies(self, device_mesh, transpose_dims):
|
33 |
+
dim_partition_list = []
|
34 |
+
dim_size = len(self.op_data['input0'].shape)
|
35 |
+
dim_partition_list.append({})
|
36 |
+
dim_partition_list.extend(
|
37 |
+
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
38 |
+
dim_partition_list.extend(
|
39 |
+
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
40 |
+
dim_partition_list.extend(
|
41 |
+
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
42 |
+
dim_partition_list.extend(
|
43 |
+
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
44 |
+
strategies_vector = StrategiesVector(self)
|
45 |
+
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
46 |
+
for dim_partition_dict in dim_partition_list:
|
47 |
+
in0_partition_dict = dim_partition_dict
|
48 |
+
out_partition_dict = {}
|
49 |
+
for split_dim, mesh_dim in in0_partition_dict.items():
|
50 |
+
trans_dim = transpose_dims[split_dim]
|
51 |
+
out_partition_dict[trans_dim] = mesh_dim
|
52 |
+
|
53 |
+
dim_partition_dict_mapping = {
|
54 |
+
"input0": in0_partition_dict,
|
55 |
+
"output0": out_partition_dict,
|
56 |
+
}
|
57 |
+
if self.num_inputs == 2:
|
58 |
+
dim_partition_dict_mapping["input1"] = {}
|
59 |
+
|
60 |
+
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
61 |
+
dim_partition_dict_mapping, device_mesh)
|
62 |
+
if 0 == len(sharding_spec_mapping):
|
63 |
+
continue
|
64 |
+
name = '{} = <shuffle_transpose_only op> {}'.format(
|
65 |
+
sharding_spec_mapping['output0'].sharding_sequence,
|
66 |
+
sharding_spec_mapping['input0'].sharding_sequence)
|
67 |
+
sharding_strategy = self._get_sharding_strategy(
|
68 |
+
name=name,
|
69 |
+
sharding_spec_mapping=sharding_spec_mapping,
|
70 |
+
communication_action_mapping={})
|
71 |
+
strategies_vector.append(sharding_strategy)
|
72 |
+
return strategies_vector
|
73 |
+
|
74 |
+
def _find_reshape_partitions(self, input_shape, output_shape,
|
75 |
+
input_partition_dict):
|
76 |
+
len_input_shape, len_output_shape = len(input_shape), len(output_shape)
|
77 |
+
output_partition_dict = {}
|
78 |
+
i, j = 0, 0
|
79 |
+
while i < len_input_shape or j < len_output_shape:
|
80 |
+
if i < len_input_shape and input_shape[i] == 1:
|
81 |
+
i = i + 1
|
82 |
+
continue
|
83 |
+
if j < len_output_shape and output_shape[j] == 1:
|
84 |
+
j = j + 1
|
85 |
+
continue
|
86 |
+
|
87 |
+
if input_shape[i] == output_shape[j]:
|
88 |
+
if i in input_partition_dict:
|
89 |
+
output_partition_dict[j] = input_partition_dict[i]
|
90 |
+
# it keeps the dimension, so we need to keep the partition dims
|
91 |
+
i, j = i + 1, j + 1
|
92 |
+
|
93 |
+
elif input_shape[i] < output_shape[j]:
|
94 |
+
# we detect if the input dims are merged in the reshape dim
|
95 |
+
value = input_shape[i]
|
96 |
+
for ii in range(i + 1, len_input_shape):
|
97 |
+
value = value * input_shape[ii]
|
98 |
+
if value == output_shape[j]:
|
99 |
+
# it is merged, we set the output's merged dim partition as all inputs' dims
|
100 |
+
mesh_dim = []
|
101 |
+
for in_dim in range(i, ii + 1):
|
102 |
+
if in_dim in input_partition_dict:
|
103 |
+
mesh_dim = mesh_dim + input_partition_dict[
|
104 |
+
in_dim]
|
105 |
+
if len(mesh_dim) > 0:
|
106 |
+
output_partition_dict[j] = sorted(mesh_dim)
|
107 |
+
i, j = ii + 1, j + 1
|
108 |
+
break
|
109 |
+
else:
|
110 |
+
# we don't find the merged dimensions, the difference may from random reshape, we don't support it now
|
111 |
+
return {}, {}
|
112 |
+
else:
|
113 |
+
# we detect if the input dim is split into reshape dims
|
114 |
+
value = output_shape[j]
|
115 |
+
for jj in range(j + 1, len_output_shape):
|
116 |
+
value = value * output_shape[jj]
|
117 |
+
if value == input_shape[i]:
|
118 |
+
# it is split pattern
|
119 |
+
if i in input_partition_dict:
|
120 |
+
output_partition_dict[j] = input_partition_dict[i]
|
121 |
+
i, j = i + 1, jj + 1
|
122 |
+
break
|
123 |
+
else:
|
124 |
+
# we don't find the split dimensions, the difference may from random reshape
|
125 |
+
return {}, {}
|
126 |
+
return input_partition_dict, output_partition_dict
|
127 |
+
|
128 |
+
def _collect_reshape_strategies(self, device_mesh):
|
129 |
+
dim_partition_list = []
|
130 |
+
dim_size = len(self.op_data['input0'].shape)
|
131 |
+
dim_partition_list.append({})
|
132 |
+
dim_partition_list.extend(
|
133 |
+
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
134 |
+
dim_partition_list.extend(
|
135 |
+
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
136 |
+
dim_partition_list.extend(
|
137 |
+
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
138 |
+
dim_partition_list.extend(
|
139 |
+
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
140 |
+
strategies_vector = StrategiesVector(self)
|
141 |
+
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
142 |
+
for dim_partition_dict in dim_partition_list:
|
143 |
+
in0_partition_dict = dim_partition_dict
|
144 |
+
in0_partition_dict, out_partition_dict = self._find_reshape_partitions(
|
145 |
+
self.input_shape, self.reshape_dims, in0_partition_dict)
|
146 |
+
dim_partition_dict_mapping = {
|
147 |
+
"input0": in0_partition_dict,
|
148 |
+
"output0": out_partition_dict,
|
149 |
+
}
|
150 |
+
if self.num_inputs == 2:
|
151 |
+
dim_partition_dict_mapping["input1"] = {}
|
152 |
+
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
153 |
+
dim_partition_dict_mapping, device_mesh)
|
154 |
+
if 0 == len(sharding_spec_mapping):
|
155 |
+
continue
|
156 |
+
name = '{} = <shuffle_reshape op> {}'.format(
|
157 |
+
sharding_spec_mapping['output0'].sharding_sequence,
|
158 |
+
sharding_spec_mapping['input0'].sharding_sequence)
|
159 |
+
sharding_strategy = self._get_sharding_strategy(
|
160 |
+
name=name,
|
161 |
+
sharding_spec_mapping=sharding_spec_mapping,
|
162 |
+
communication_action_mapping={})
|
163 |
+
strategies_vector.append(sharding_strategy)
|
164 |
+
return strategies_vector
|
165 |
+
|
166 |
+
def _collect_identity_strategies(self, device_mesh):
|
167 |
+
dim_partition_list = []
|
168 |
+
dim_size = len(self.op_data['input0'].shape)
|
169 |
+
dim_partition_list.append({})
|
170 |
+
dim_partition_list.extend(
|
171 |
+
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
172 |
+
dim_partition_list.extend(
|
173 |
+
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
174 |
+
dim_partition_list.extend(
|
175 |
+
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
176 |
+
dim_partition_list.extend(
|
177 |
+
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
178 |
+
strategies_vector = StrategiesVector(self)
|
179 |
+
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
180 |
+
for dim_partition_dict in dim_partition_list:
|
181 |
+
in0_partition_dict = dim_partition_dict
|
182 |
+
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
183 |
+
dim_partition_dict_mapping = {
|
184 |
+
"input0": in0_partition_dict,
|
185 |
+
"output0": out_partition_dict,
|
186 |
+
}
|
187 |
+
if self.num_inputs == 2:
|
188 |
+
dim_partition_dict_mapping["input1"] = {}
|
189 |
+
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
190 |
+
dim_partition_dict_mapping, device_mesh)
|
191 |
+
if 0 == len(sharding_spec_mapping):
|
192 |
+
continue
|
193 |
+
name = '{} = <shuffle_identity op> {}'.format(
|
194 |
+
sharding_spec_mapping['output0'].sharding_sequence,
|
195 |
+
sharding_spec_mapping['input0'].sharding_sequence)
|
196 |
+
sharding_strategy = self._get_sharding_strategy(
|
197 |
+
name=name,
|
198 |
+
sharding_spec_mapping=sharding_spec_mapping,
|
199 |
+
communication_action_mapping={})
|
200 |
+
strategies_vector.append(sharding_strategy)
|
201 |
+
return strategies_vector
|
202 |
+
|
203 |
+
def _collect_strategies(self, device_mesh):
|
204 |
+
is_identify_list = (self.is_first_transepose_identity,
|
205 |
+
self.is_reshape_identity,
|
206 |
+
self.is_second_transepose_identity)
|
207 |
+
if is_identify_list == (True, True, True):
|
208 |
+
return self._collect_identity_strategies(device_mesh)
|
209 |
+
elif is_identify_list == (True, True, False):
|
210 |
+
return self._collect_transpose_strategies(
|
211 |
+
device_mesh, self.second_transpose_dims)
|
212 |
+
elif is_identify_list == (False, True, True):
|
213 |
+
return self._collect_transpose_strategies(device_mesh,
|
214 |
+
self.first_transpose_dims)
|
215 |
+
elif is_identify_list == (True, False, True):
|
216 |
+
return self._collect_reshape_strategies(device_mesh)
|
217 |
+
else:
|
218 |
+
assert False, f"Unsupported shuffle pattern now {is_identify_list}"
|
219 |
+
|
220 |
+
def _profile_sharding_cost(self, strategy, device_mesh):
|
221 |
+
updated_layer_attrs = {}
|
222 |
+
updated_input_values = {}
|
223 |
+
output_shape = strategy.sharding_specs[
|
224 |
+
'output0'].get_sharded_shape_per_device()
|
225 |
+
self.layer.to_subclass()
|
226 |
+
second_transpose = self.layer.as_trt().second_transpose
|
227 |
+
self.layer.to_base_class()
|
228 |
+
reshape_dims = [*output_shape]
|
229 |
+
for i in range(len(output_shape)):
|
230 |
+
reshape_dims[second_transpose[i]] = output_shape[i]
|
231 |
+
if self.layer.num_inputs >= 2:
|
232 |
+
updated_input_values[1] = reshape_dims
|
233 |
+
else:
|
234 |
+
updated_layer_attrs['reshape_dims'] = reshape_dims
|
235 |
+
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
236 |
+
self.layer, updated_layer_attrs, updated_input_values, strategy,
|
237 |
+
device_mesh)
|
238 |
+
return elapsed_time
|
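For readers of this diff, a small illustration (not part of the uploaded file) of what `_find_reshape_partitions` above computes. The shapes and mesh assignments are hypothetical, and partition dicts are assumed to map tensor dims to lists of mesh dims as in the rest of this module:

    # Merge case: reshaping (4, 8, 16) -> (32, 16) folds input dims 0 and 1 into
    # output dim 0, so their mesh dims are concatenated (and sorted) on that dim.
    merge_input_partitions = {0: [0], 1: [1]}
    merge_output_partitions = {0: [0, 1]}    # expected result for 'output0'

    # Split case: reshaping (32, 16) -> (4, 8, 16) splits input dim 0 across output
    # dims 0 and 1; its sharding is kept on the leading output dim only.
    split_input_partitions = {0: [0]}
    split_output_partitions = {0: [0]}       # expected result for 'output0'

If neither a clean merge nor a clean split is found, the method returns two empty dicts, so that candidate falls back to a fully replicated mapping.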
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/slice_node.py
ADDED
@@ -0,0 +1,100 @@
import copy

from .node import Node
from .sharding_strategy import StrategiesVector


class Slice(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        input_shape = self.get_input(0).shape
        output_shape = self.get_output(0).shape
        assert len(input_shape) == len(
            output_shape
        ), f'dims of input shape {input_shape} != dims of output shape {output_shape}'
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            start = layer.get_input(1).value
        else:
            start = layer.as_trt().start
        if layer.num_inputs >= 4 and layer.get_input(3) is not None:
            stride = layer.get_input(3).value
        else:
            stride = layer.as_trt().stride
        self.keep_partition_dims = [(input_shape[i] == output_shape[i]
                                     and start[i] == 0 and stride[i] == 1)
                                    for i in range(len(input_shape))]
        layer.to_base_class()

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            # for the slice node, input0's read size equals output0's write size
            inout_memory_footprint = strategy.sharding_specs[
                'output0'].get_sharded_size_per_device() * 2
            strategy.inout_memory_footprint = inout_memory_footprint
            strategy.peak_memory_footprint = (
                strategy.sharding_specs['input0'].
                get_max_sharded_size_per_device() + strategy.
                sharding_specs['output0'].get_max_sharded_size_per_device())

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        # dim_partition_dict can be the same as previous node if solver's time is a problem
        for dim_partition_dict in dim_partition_list:
            for dim in range(len(self.keep_partition_dims)):
                if (not self.keep_partition_dims[dim]
                    ) and dim in dim_partition_dict:
                    dim_partition_dict.pop(dim)

            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            for i in range(1, self.num_inputs):
                if self.predecessor_nodes[i]:
                    dim_partition_dict_mapping[f"input{i}"] = {}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = {} <slice op> '.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            for i in range(1, self.num_inputs):
                if self.predecessor_nodes[i]:
                    name = name + str(
                        sharding_spec_mapping[f'input{i}'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        updated_layer_attrs = {}
        updated_input_values = {}
        shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
        )
        if self.layer.num_inputs >= 3 and self.layer.get_input(2) is not None:
            updated_input_values[2] = shape
        else:
            updated_layer_attrs['shape'] = shape
        elapsed_time = self.node_runtime_profiler.runtime_profile(
            self.layer, updated_layer_attrs, updated_input_values, strategy,
            device_mesh)
        return elapsed_time
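Similarly, a brief sketch (not part of the uploaded file) of how the `keep_partition_dims` mask computed in `Slice.__init__` is expected to prune candidate shardings; the shape, start, and stride below are hypothetical:

    # Slicing an (8, 32, 64) tensor with start (0, 16, 0) and stride (1, 1, 1)
    # to shape (8, 16, 64): dims 0 and 2 pass through unchanged (same size,
    # start 0, stride 1), while dim 1 is actually cut.
    keep_partition_dims = [True, False, True]
    candidate = {0: [0], 1: [1]}   # proposed sharding: dim 0 on mesh axis 0, dim 1 on mesh axis 1
    # _collect_strategies drops the dim 1 entry, leaving {0: [0]}.

Only dimensions the slice passes through untouched may stay sharded; a dimension that is actually sliced is kept replicated, presumably so every device can apply the same slice parameters locally.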
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/softmax_node.py
ADDED
@@ -0,0 +1,54 @@
import copy

from tensorrt_llm._utils import trt_axes_to_dim

from .node import Node
from .sharding_strategy import StrategiesVector


class SoftMax(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.softmax_dim = trt_axes_to_dim(layer.as_trt().axes)[0]
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        # dim_partition_dict can be the same as previous node if solver's time is a problem
        for dim_partition_dict in dim_partition_list:
            if self.softmax_dim in dim_partition_dict:
                dim_partition_dict.pop(self.softmax_dim)

            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <softmax along dim {}> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                self.softmax_dim,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
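Finally, a short note with a sketch (not part of the uploaded file) on the softmax node: the reduction runs along `self.softmax_dim`, so `_collect_strategies` removes that dimension from every candidate partition before building sharding specs, keeping the normalization local to each device. Assuming `trt_axes_to_dim` converts the TensorRT axes bitmask into a list of dimension indices (an assumption about `tensorrt_llm._utils`), a softmax over axis 2 of a 4-D tensor would behave like:

    softmax_dim = 2                    # e.g. trt_axes_to_dim(1 << 2)[0], under the bitmask assumption above
    candidate = {0: [0], 2: [1]}       # proposed sharding for a 4-D input
    candidate.pop(softmax_dim, None)   # the reduced dim must stay replicated
    assert candidate == {0: [0]}       # only non-reduced dims may remain sharded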