aspctu committed on
Commit
5000658
1 Parent(s): 2223bef

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. lib.linux-x86_64-cpython-310/tensorrt_llm/__init__.py +96 -0
  3. lib.linux-x86_64-cpython-310/tensorrt_llm/_common.py +268 -0
  4. lib.linux-x86_64-cpython-310/tensorrt_llm/_ipc_utils.py +139 -0
  5. lib.linux-x86_64-cpython-310/tensorrt_llm/_utils.py +525 -0
  6. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/__init__.py +9 -0
  7. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/auto_parallel.py +263 -0
  8. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/cluster_info.py +556 -0
  9. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/config.py +61 -0
  10. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/device_mesh.py +612 -0
  11. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/node_graph.py +347 -0
  12. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py +0 -0
  13. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/pipeline_graph.py +1035 -0
  14. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/runtime_profiling.py +150 -0
  15. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/shape_info.py +362 -0
  16. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/simplifier.py +837 -0
  17. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/solver.py +641 -0
  18. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py +0 -0
  19. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py +41 -0
  20. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py +34 -0
  21. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py +45 -0
  22. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py +58 -0
  23. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py +56 -0
  24. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py +45 -0
  25. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py +49 -0
  26. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py +59 -0
  27. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py +196 -0
  28. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py +56 -0
  29. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py +79 -0
  30. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py +798 -0
  31. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/node.py +376 -0
  32. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py +60 -0
  33. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py +79 -0
  34. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py +67 -0
  35. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py +40 -0
  36. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py +0 -0
  37. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py +27 -0
  38. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py +395 -0
  39. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py +11 -0
  40. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py +19 -0
  41. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py +28 -0
  42. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py +73 -0
  43. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py +56 -0
  44. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_consistency.py +832 -0
  45. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_node.py +41 -0
  46. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_spec.py +418 -0
  47. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_strategy.py +77 -0
  48. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shuffle_node.py +238 -0
  49. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/slice_node.py +100 -0
  50. lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/softmax_node.py +54 -0
.gitattributes CHANGED
@@ -34,3 +34,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tensorrt_llm-0.12.0.dev2024072300-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
+ lib.linux-x86_64-cpython-310/tensorrt_llm/bin/executorWorker filter=lfs diff=lfs merge=lfs -text
38
+ lib.linux-x86_64-cpython-310/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
39
+ lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libdecoder_attention.so filter=lfs diff=lfs merge=lfs -text
40
+ lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so filter=lfs diff=lfs merge=lfs -text
41
+ lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libtensorrt_llm.so filter=lfs diff=lfs merge=lfs -text
42
+ lib.linux-x86_64-cpython-310/tensorrt_llm/libs/libtensorrt_llm_nvrtc_wrapper.so filter=lfs diff=lfs merge=lfs -text
lib.linux-x86_64-cpython-310/tensorrt_llm/__init__.py ADDED
@@ -0,0 +1,96 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ def _add_trt_llm_dll_directory():
18
+ import platform
19
+ on_windows = platform.system() == "Windows"
20
+ if on_windows:
21
+ import os
22
+ import sysconfig
23
+ from pathlib import Path
24
+ os.add_dll_directory(
25
+ Path(sysconfig.get_paths()['purelib']) / "tensorrt_llm" / "libs")
26
+
27
+
28
+ _add_trt_llm_dll_directory()
29
+
30
+ import sys
31
+
32
+ import tensorrt_llm.functional as functional
33
+ import tensorrt_llm.models as models
34
+ import tensorrt_llm.quantization as quantization
35
+ import tensorrt_llm.runtime as runtime
36
+ import tensorrt_llm.tools as tools
37
+
38
+ from ._common import _init, default_net, default_trtnet, precision
39
+ # Disable flake8 on the line below because mpi_barrier is not used in tensorrt_llm project
40
+ # but may be called in dependencies (such as examples)
41
+ from ._utils import mpi_barrier # NOQA
42
+ from ._utils import str_dtype_to_torch # NOQA
43
+ from ._utils import (mpi_rank, mpi_world_size, str_dtype_to_trt,
44
+ torch_dtype_to_trt)
45
+ from .auto_parallel import AutoParallelConfig, auto_parallel
46
+ from .builder import BuildConfig, Builder, BuilderConfig, build
47
+ from .functional import Tensor, constant
48
+ from .hlapi.llm import LLM, LlmArgs, SamplingParams
49
+ from .logger import logger
50
+ from .mapping import Mapping
51
+ from .module import Module
52
+ from .network import Network, net_guard
53
+ from .parameter import Parameter
54
+ from .version import __version__
55
+
56
+ __all__ = [
57
+ 'logger',
58
+ 'str_dtype_to_trt',
59
+ 'torch_dtype_to_trt',
60
+ 'str_dtype_to_torch',
61
+ 'mpi_barrier',
62
+ 'mpi_rank',
63
+ 'mpi_world_size',
64
+ 'constant',
65
+ 'default_net',
66
+ 'default_trtnet',
67
+ 'precision',
68
+ 'net_guard',
69
+ 'Network',
70
+ 'Mapping',
71
+ 'Builder',
72
+ 'BuilderConfig',
73
+ 'build',
74
+ 'BuildConfig',
75
+ 'Tensor',
76
+ 'Parameter',
77
+ 'runtime',
78
+ 'Module',
79
+ 'functional',
80
+ 'models',
81
+ 'auto_parallel',
82
+ 'AutoParallelConfig',
83
+ 'quantization',
84
+ 'tools',
85
+ 'LLM',
86
+ 'LlmArgs',
87
+ 'SamplingParams',
88
+ 'KvCacheConfig',
89
+ '__version__',
90
+ ]
91
+
92
+ _init(log_level="error")
93
+
94
+ print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}")
95
+
96
+ sys.stdout.flush()
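
A minimal standalone sketch of the Windows-only DLL search-path setup that this __init__.py performs before importing its submodules; the package name is the only assumption, and on other platforms the helper is a no-op:

import platform
import sysconfig
from pathlib import Path

def register_bundled_libs(package: str = "tensorrt_llm"):
    # os.add_dll_directory only exists on Windows builds of CPython (3.8+),
    # so the call is guarded; on Linux/macOS nothing needs to be registered.
    if platform.system() != "Windows":
        return None
    import os
    libs_dir = Path(sysconfig.get_paths()["purelib"]) / package / "libs"
    return os.add_dll_directory(str(libs_dir))
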
lib.linux-x86_64-cpython-310/tensorrt_llm/_common.py ADDED
@@ -0,0 +1,268 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import contextlib
16
+ import ctypes
17
+ import os
18
+ import platform
19
+ import time
20
+ from functools import wraps
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+
25
+ # isort: off
26
+ import torch
27
+ import tensorrt as trt
28
+
29
+ # isort: on
30
+
31
+ from ._utils import str_dtype_to_trt
32
+ from .bindings import MpiComm
33
+ from .logger import logger
34
+ from .plugin import _load_plugin_lib
35
+
36
+ net = None
37
+
38
+ _inited = False
39
+
40
+
41
+ def _init(log_level: object = None) -> None:
42
+ global _inited
43
+ if _inited:
44
+ return
45
+ _inited = True
46
+ # Move to __init__
47
+ if log_level is not None:
48
+ logger.set_level(log_level)
49
+
50
+ if os.getenv("TRT_LLM_NO_LIB_INIT", "0") == "1":
51
+ logger.info('Skipping TensorRT-LLM init.')
52
+ return
53
+
54
+ logger.info('Starting TensorRT-LLM init.')
55
+
56
+ # load plugin lib
57
+ _load_plugin_lib()
58
+
59
+ # load FT decoder layer
60
+ project_dir = str(Path(__file__).parent.absolute())
61
+ if platform.system() == "Windows":
62
+ ft_decoder_lib = project_dir + '/libs/th_common.dll'
63
+ else:
64
+ ft_decoder_lib = project_dir + '/libs/libth_common.so'
65
+ try:
66
+ torch.classes.load_library(ft_decoder_lib)
67
+ except Exception as e:
68
+ msg = '\nFATAL: Decoding operators failed to load. This may be caused by the incompatibility between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM.'
69
+ raise ImportError(str(e) + msg)
70
+
71
+ MpiComm.local_init()
72
+
73
+ logger.info('TensorRT-LLM inited.')
74
+
75
+
76
+ def default_net():
77
+ assert net, "Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
78
+ return net
79
+
80
+
81
+ def default_trtnet():
82
+ return default_net().trt_network
83
+
84
+
85
+ def set_network(network):
86
+ global net
87
+ net = network
88
+
89
+
90
+ def switch_net_dtype(cur_dtype):
91
+ prev_dtype = default_net().dtype
92
+ default_net().dtype = cur_dtype
93
+ return prev_dtype
94
+
95
+
96
+ @contextlib.contextmanager
97
+ def precision(dtype):
98
+ if isinstance(dtype, str):
99
+ dtype = str_dtype_to_trt(dtype)
100
+ prev_dtype = switch_net_dtype(dtype)
101
+ yield
102
+ switch_net_dtype(prev_dtype)
103
+
104
+
105
+ def serialize_engine(engine, path):
106
+ logger.info(f'Serializing engine to {path}...')
107
+ tik = time.time()
108
+ if isinstance(engine, trt.ICudaEngine):
109
+ engine = engine.serialize()
110
+ with open(path, 'wb') as f:
111
+ f.write(engine)
112
+ tok = time.time()
113
+ t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
114
+ logger.info(f'Engine serialized. Total time: {t}')
115
+
116
+
117
+ def deserialize_engine(path):
118
+ runtime = trt.Runtime(logger.trt_logger)
119
+ with open(path, 'rb') as f:
120
+ logger.info(f'Loading engine from {path}...')
121
+ tik = time.time()
122
+
123
+ engine = runtime.deserialize_cuda_engine(f.read())
124
+ assert engine is not None
125
+
126
+ tok = time.time()
127
+ t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
128
+ logger.info(f'Engine loaded. Total time: {t}')
129
+ return engine
130
+
131
+
132
+ _field_dtype_to_np_dtype_dict = {
133
+ trt.PluginFieldType.FLOAT16: np.float16,
134
+ trt.PluginFieldType.FLOAT32: np.float32,
135
+ trt.PluginFieldType.FLOAT64: np.float64,
136
+ trt.PluginFieldType.INT8: np.int8,
137
+ trt.PluginFieldType.INT16: np.int16,
138
+ trt.PluginFieldType.INT32: np.int32,
139
+ }
140
+
141
+
142
+ def field_dtype_to_np_dtype(dtype):
143
+ ret = _field_dtype_to_np_dtype_dict.get(dtype)
144
+ assert ret is not None, f'Unsupported dtype: {dtype}'
145
+ return ret
146
+
147
+
148
+ def convert_capsule_to_void_p(capsule):
149
+ ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
150
+ ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
151
+ ctypes.py_object, ctypes.c_char_p
152
+ ]
153
+ return ctypes.pythonapi.PyCapsule_GetPointer(capsule, None)
154
+
155
+
156
+ def get_nparray_from_void_p(void_pointer, elem_size, field_dtype):
157
+ ctypes.pythonapi.PyMemoryView_FromMemory.restype = ctypes.py_object
158
+ ctypes.pythonapi.PyMemoryView_FromMemory.argtypes = [
159
+ ctypes.c_char_p, ctypes.c_ssize_t, ctypes.c_int
160
+ ]
161
+ logger.info(
162
+ f'get_nparray: pointer = {void_pointer}, elem_size = {elem_size}')
163
+ char_pointer = ctypes.cast(void_pointer, ctypes.POINTER(ctypes.c_char))
164
+ np_dtype = field_dtype_to_np_dtype(field_dtype)
165
+ buf_bytes = elem_size * np.dtype(np_dtype).itemsize
166
+ logger.info(f'get_nparray: buf_bytes = {buf_bytes}')
167
+ mem_view = ctypes.pythonapi.PyMemoryView_FromMemory(
168
+ char_pointer, buf_bytes, 0) # number 0 represents PyBUF_READ
169
+ logger.info(
170
+ f'get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}')
171
+ buf = np.frombuffer(mem_view, np_dtype)
172
+ return buf
173
+
174
+
175
+ def get_scalar_from_field(field):
176
+ void_p = convert_capsule_to_void_p(field.data)
177
+ np_array = get_nparray_from_void_p(void_p, 1, field.type)
178
+ return np_array[0]
179
+
180
+
181
+ class _BuildingFlag:
182
+
183
+ def __enter__(self):
184
+ os.environ['IS_BUILDING'] = '1'
185
+ os.environ[
186
+ '__LUNOWUD'] = '-cask_fusion:fp8=off' # will be removed in future releases
187
+
188
+ def __exit__(self, type, value, tb):
189
+ del os.environ['IS_BUILDING']
190
+ del os.environ['__LUNOWUD']
191
+
192
+
193
+ def _is_building(f):
194
+ '''Use this to decorate functions that are called during the engine building/refitting process;
195
+ otherwise, the plugin registration will fail.
196
+ '''
197
+
198
+ @wraps(f)
199
+ def decorated(*args, **kwargs):
200
+ with _BuildingFlag():
201
+ return f(*args, **kwargs)
202
+
203
+ return decorated
204
+
205
+
206
+ def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
207
+ max_input_len, max_seq_len, max_beam_width,
208
+ remove_input_padding, enable_context_fmha,
209
+ tokens_per_block, multiple_profiles):
210
+ if not remove_input_padding:
211
+ if max_num_tokens is not None or opt_num_tokens is not None:
212
+ max_num_tokens = max_batch_size * max_seq_len
213
+ logger.warning("remove_input_padding is not enabled, the specified "
214
+ "max_num_tokens/opt_num_tokens will be ignored.")
215
+ return max_num_tokens, opt_num_tokens
216
+ else:
217
+ if max_num_tokens is None:
218
+ max_num_tokens = max_seq_len * max_batch_size
219
+ logger.warning(
220
+ "remove_input_padding is enabled, while max_num_tokens "
221
+ "is not set, setting to max_batch_size*max_seq_len. \n"
222
+ "It may not be optimal to set max_num_tokens=max_batch_size*max_seq_len "
223
+ "when remove_input_padding is enabled, because the number "
224
+ "of packed input tokens are very likely to be smaller, "
225
+ "we strongly recommend to set max_num_tokens according "
226
+ "to your workloads.")
227
+ if opt_num_tokens is None and not multiple_profiles:
228
+ opt_num_tokens = min(max_batch_size * max_beam_width,
229
+ max_num_tokens)
230
+ logger.warning(
231
+ "remove_input_padding is enabled, while opt_num_tokens "
232
+ "is not set, setting to max_batch_size*max_beam_width. \n")
233
+ if max_num_tokens > 16384:
234
+ logger.warning(
235
+ "Specifying a `max_num_tokens` larger than 16384 is usually "
236
+ "not recommended, we do not expect perf gain with that and too "
237
+ "large `max_num_tokens` could possibly exceed the TensorRT "
238
+ "tensor volume, causing runtime errors. "
239
+ f"Got `max_num_tokens` = {max_num_tokens}")
240
+ if max_num_tokens > max_seq_len * max_batch_size:
241
+ max_num_tokens = max_seq_len * max_batch_size
242
+ logger.warning(
243
+ f"max_num_tokens ({max_num_tokens}) shouldn't be greater than "
244
+ f"max_seq_len * max_batch_size ({max_seq_len * max_batch_size}), "
245
+ f"specifying to max_seq_len * max_batch_size ({max_seq_len * max_batch_size})."
246
+ )
247
+ if max_num_tokens < max_input_len and not enable_context_fmha:
248
+ logger.warning(
249
+ f"When enable_context_fmha is not turned on, max_num_tokens ({max_num_tokens}) "
250
+ f"should be at least max_input_len ({max_input_len}), specifying to "
251
+ f"max_input_len ({max_input_len}).")
252
+ max_num_tokens = max_input_len
253
+ elif max_num_tokens < tokens_per_block and enable_context_fmha:
254
+ logger.warning(
255
+ f"When enable_context_fmha is turned on, max_num_tokens ({max_num_tokens}) "
256
+ f"should be at least tokens_per_block ({tokens_per_block}), specifying to "
257
+ f"tokens_per_block ({tokens_per_block}). At this time, you also need to enable "
258
+ f"context chunking at runtime, otherwise you may encounter errors.")
259
+ max_num_tokens = tokens_per_block
260
+
261
+ if opt_num_tokens is not None and opt_num_tokens > max_num_tokens:
262
+ logger.warning(
263
+ f"opt_num_tokens ({opt_num_tokens}) shouldn't be greater than "
264
+ f"max_num_tokens ({max_num_tokens}), "
265
+ f"specifying to max_num_tokens ({max_num_tokens}).")
266
+ opt_num_tokens = max_num_tokens
267
+
268
+ return max_num_tokens, opt_num_tokens
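
A minimal, self-contained sketch of the flag-scoping pattern behind _BuildingFlag and _is_building above: an environment flag is set only for the duration of the decorated call. The sketch keeps just the IS_BUILDING flag, drops the temporary __LUNOWUD setting, and adds a try/finally for cleanup:

import os
from contextlib import contextmanager
from functools import wraps

@contextmanager
def building_flag():
    os.environ["IS_BUILDING"] = "1"
    try:
        yield
    finally:
        del os.environ["IS_BUILDING"]

def is_building(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        # The flag is visible to anything invoked inside the decorated call.
        with building_flag():
            return f(*args, **kwargs)
    return decorated

@is_building
def build_engine():
    return os.environ.get("IS_BUILDING")

print(build_engine())                  # '1' while the decorated call runs
print(os.environ.get("IS_BUILDING"))   # None once the call has returned
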
lib.linux-x86_64-cpython-310/tensorrt_llm/_ipc_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import array
16
+ import struct
17
+ import sys
18
+ from contextlib import contextmanager
19
+ from typing import List, Tuple
20
+
21
+ from cuda import cudart
22
+ from cuda.cudart import cudaError_t
23
+
24
+ from ._utils import mpi_comm
25
+ from .mapping import Mapping
26
+
27
+
28
+ def _raise_if_error(error: cudaError_t):
29
+ if error != cudaError_t.cudaSuccess:
30
+ raise RuntimeError(error)
31
+
32
+
33
+ @contextmanager
34
+ def peer_access(mapping: Mapping):
35
+ set_peer_access(mapping, True)
36
+ try:
37
+ yield
38
+ finally:
39
+ set_peer_access(mapping, False)
40
+
41
+
42
+ def set_peer_access(mapping: Mapping, enabled: bool = True):
43
+ src_node = mapping.local_rank
44
+ for rank in mapping.tp_group:
45
+ dest_node = mapping.get_local_rank(rank)
46
+ if mapping.get_node_rank(
47
+ rank) != mapping.node_rank or dest_node == src_node:
48
+ continue
49
+
50
+ error, result = cudart.cudaDeviceCanAccessPeer(src_node, dest_node)
51
+ _raise_if_error(error)
52
+
53
+ if result == 0:
54
+ raise RuntimeError(
55
+ f"Can't enable access between nodes {src_node} and {dest_node}")
56
+
57
+ if enabled:
58
+ cudart.cudaDeviceEnablePeerAccess(dest_node, 0)
59
+ else:
60
+ cudart.cudaDeviceDisablePeerAccess(dest_node)
61
+ error = cudart.cudaGetLastError()[0]
62
+ if error not in [
63
+ cudaError_t.cudaSuccess,
64
+ cudaError_t.cudaErrorPeerAccessAlreadyEnabled,
65
+ cudaError_t.cudaErrorPeerAccessNotEnabled
66
+ ]:
67
+ raise RuntimeError(error)
68
+
69
+
70
+ class IpcMemory():
71
+
72
+ # WARNING: Must be kept in sync with FLAGS_SIZE in cpp/include/tensorrt_llm/runtime/ipcUtils.h
73
+ # (Max all reduce blocks + 1) * sizeof(int)
74
+ IPC_BARRIERS_SIZE_PER_GPU = (24 + 1) * 4
75
+
76
+ def __init__(self, mapping: Mapping, size: int):
77
+ self.mapping = mapping
78
+ self.open_ipc = mapping.tp_size <= mapping.gpus_per_node
79
+ if self.open_ipc:
80
+ self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(
81
+ self.mapping, size, True)
82
+ else:
83
+ self.peer_ptrs = [0] * mapping.tp_size
84
+ self.local_ptr = 0
85
+
86
+ def __del__(self):
87
+ if not sys.is_finalizing() and self.open_ipc:
88
+ IpcMemory.close_ipc_memory(self.mapping, self.peer_ptrs)
89
+
90
+ def serialize(self) -> List[int]:
91
+ buffer = bytes(0)
92
+ for ptr in self.peer_ptrs:
93
+ buffer += struct.pack("P", ptr)
94
+
95
+ return array.array("Q", buffer).tolist()
96
+
97
+ @staticmethod
98
+ def open_ipc_memory(mapping: Mapping,
99
+ size: int,
100
+ set_to_zero: bool = False) -> Tuple[List[int], int]:
101
+ """ Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
102
+ Returns a list of buffer pointers, buffers[i] is a handle to the corresponding buffer residing on GPU #i.
103
+ Call close_ipc_memory with the returned *peer_ptrs*.
104
+ """
105
+ comm = mpi_comm().Split(mapping.pp_rank, mapping.tp_rank)
106
+
107
+ error, local_ptr = cudart.cudaMalloc(size)
108
+ _raise_if_error(error)
109
+ if set_to_zero:
110
+ _raise_if_error(cudart.cudaMemset(local_ptr, 0, size)[0])
111
+ error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
112
+ _raise_if_error(error)
113
+
114
+ handles_reserved = comm.allgather(local_handle.reserved)
115
+ handles = []
116
+ for reserved in handles_reserved:
117
+ handle = cudart.cudaIpcMemHandle_t()
118
+ handle.reserved = reserved
119
+ handles.append(handle)
120
+
121
+ peer_ptrs = []
122
+ for node, handle in enumerate(handles):
123
+ if node == mapping.tp_rank:
124
+ peer_ptrs.append(local_ptr)
125
+ else:
126
+ error, ptr = cudart.cudaIpcOpenMemHandle(
127
+ handle, cudart.cudaIpcMemLazyEnablePeerAccess)
128
+ _raise_if_error(error)
129
+ peer_ptrs.append(ptr)
130
+
131
+ return peer_ptrs, local_ptr
132
+
133
+ @staticmethod
134
+ def close_ipc_memory(mapping: Mapping, peer_ptrs: List[int]):
135
+ for node, ptr in enumerate(peer_ptrs):
136
+ if node == mapping.tp_rank:
137
+ _raise_if_error(cudart.cudaFree(ptr)[0])
138
+ else:
139
+ _raise_if_error(cudart.cudaIpcCloseMemHandle(ptr)[0])
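
A standalone illustration of the pointer packing done in IpcMemory.serialize() above, assuming a 64-bit build where the native pointer format "P" and the uint64 format "Q" are both 8 bytes wide; the pointer values below are made up:

import array
import struct

peer_ptrs = [0x7F0000001000, 0x7F0000002000, 0]  # hypothetical device pointers
buffer = bytes(0)
for ptr in peer_ptrs:
    buffer += struct.pack("P", ptr)  # each pointer becomes 8 raw bytes

# Reinterpret the byte string as a list of uint64 words, one per pointer.
print(array.array("Q", buffer).tolist())
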
lib.linux-x86_64-cpython-310/tensorrt_llm/_utils.py ADDED
@@ -0,0 +1,525 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import gc
17
+ import inspect
18
+ import json
19
+ import math
20
+ import struct
21
+ import weakref
22
+ from dataclasses import asdict
23
+ from enum import EnumMeta
24
+ from functools import partial
25
+ from typing import Any, Dict, List, Optional, Union
26
+
27
+ import numpy as np
28
+ from packaging import version
29
+
30
+ from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
31
+
32
+ # isort: off
33
+ import torch
34
+ import tensorrt as trt
35
+ # isort: on
36
+
37
+ # numpy doesn't know bfloat16, define abstract binary type instead
38
+ np_bfloat16 = np.dtype('V2', metadata={"dtype": "bfloat16"})
39
+ np_float8 = np.dtype('V1', metadata={"dtype": "float8"})
40
+
41
+
42
+ def torch_to_numpy(x: torch.Tensor):
43
+ assert isinstance(x, torch.Tensor), \
44
+ f'x must be a torch.Tensor object, but got {type(x)}.'
45
+ if x.dtype == torch.bfloat16:
46
+ return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16)
47
+ elif x.dtype == torch.float8_e4m3fn:
48
+ return x.view(torch.int8).detach().cpu().numpy().view(np_float8)
49
+ else:
50
+ return x.detach().cpu().numpy()
51
+
52
+
53
+ def numpy_to_torch(x):
54
+ if x.dtype == np_bfloat16:
55
+ return torch.from_numpy(x.view(np.int16)).view(torch.bfloat16)
56
+ elif x.dtype == np_float8:
57
+ return torch.from_numpy(x.view(np.int8)).view(torch.float8_e4m3fn)
58
+ else:
59
+ return torch.from_numpy(x)
60
+
61
+
62
+ def numpy_to_dtype(x, dtype: str):
63
+ if str_dtype_to_np(dtype) == x.dtype:
64
+ return x
65
+ if x.dtype not in [np_bfloat16, np_float8
66
+ ] and dtype not in ['bfloat16', 'fp8']:
67
+ return x.astype(str_dtype_to_np(dtype))
68
+ else:
69
+ return torch_to_numpy(numpy_to_torch(x).to(str_dtype_to_torch(dtype)))
70
+
71
+
72
+ fp32_array = partial(np.array, dtype=np.float32)
73
+ fp16_array = partial(np.array, dtype=np.float16)
74
+ int32_array = partial(np.array, dtype=np.int32)
75
+ int64_array = partial(np.array, dtype=np.int64)
76
+ bool_array = partial(np.array, dtype=np.bool_)
77
+
78
+
79
+ def dims_array(x):
80
+ is_int64_dims = True
81
+ try:
82
+ trt.Dims([np.iinfo(np.int64).max])
83
+ except TypeError:
84
+ is_int64_dims = False
85
+ return int64_array(x) if is_int64_dims else int32_array(x)
86
+
87
+
88
+ def bf16_array(x):
89
+ x = torch.tensor(x, dtype=torch.bfloat16)
90
+ x = torch_to_numpy(x)
91
+ return x
92
+
93
+
94
+ def numpy_array(data, trt_dtype):
95
+ # convenient wrapper since numpy does not support bf16 yet
96
+ if trt_dtype == trt.bfloat16:
97
+ return bf16_array(data)
98
+ return np.array(data, trt_dtype_to_np(trt_dtype))
99
+
100
+
101
+ def copy_torch_to_numpy(x: torch.Tensor, ndarray: np.array):
102
+ if x.dtype == torch.bfloat16:
103
+ torch.from_numpy(ndarray.view(np.int16)).copy_(x.view(torch.int16))
104
+ elif x.dtype == torch.float8_e4m3fn:
105
+ torch.from_numpy(ndarray.view(np.int8)).copy_(x.view(torch.int8))
106
+ else:
107
+ torch.from_numpy(ndarray).copy_(x)
108
+ return ndarray
109
+
110
+
111
+ def trt_version():
112
+ return trt.__version__
113
+
114
+
115
+ # TRT supports strongly_typed in 9.1
116
+ def support_strongly_type():
117
+ return version.parse(trt_version()) >= version.parse("9.1.0")
118
+
119
+
120
+ # Check if TRT version >= 10
121
+ def trt_gte_10():
122
+ return version.parse(trt_version()).major > 9
123
+
124
+
125
+ # Check if TRT version >= 10.1
126
+ def trt_gte_10_1():
127
+ trt_ver = version.parse(trt_version())
128
+ return trt_ver.major > 9 and trt_ver.minor > 0
129
+
130
+
131
+ # Check if TRT version >= 10.2
132
+ def trt_gte_10_2():
133
+ ver = version.parse(trt_version())
134
+ return (ver.major * 10 + ver.minor) >= 102
135
+
136
+
137
+ def torch_version():
138
+ return torch.__version__
139
+
140
+
141
+ _str_to_np_dict = dict(
142
+ float16=np.float16,
143
+ float32=np.float32,
144
+ int64=np.int64,
145
+ int32=np.int32,
146
+ int8=np.int8,
147
+ bool=np.bool_,
148
+ bfloat16=np_bfloat16,
149
+ fp8=np_float8,
150
+ )
151
+
152
+
153
+ def str_dtype_to_np(dtype):
154
+ ret = _str_to_np_dict.get(dtype)
155
+ assert ret is not None, f'Unsupported dtype: {dtype}'
156
+ return ret
157
+
158
+
159
+ _str_to_torch_dtype_dict = dict(
160
+ bfloat16=torch.bfloat16,
161
+ float16=torch.float16,
162
+ float32=torch.float32,
163
+ int64=torch.int64,
164
+ int32=torch.int32,
165
+ int8=torch.int8,
166
+ bool=torch.bool,
167
+ fp8=torch.float8_e4m3fn,
168
+ )
169
+
170
+
171
+ def str_dtype_to_torch(dtype):
172
+ ret = _str_to_torch_dtype_dict.get(dtype)
173
+ assert ret is not None, f'Unsupported dtype: {dtype}'
174
+ return ret
175
+
176
+
177
+ _torch_dtype_to_str_dict = {v: k for k, v in _str_to_torch_dtype_dict.items()}
178
+
179
+
180
+ def torch_dtype_to_str(dtype):
181
+ return _torch_dtype_to_str_dict[dtype]
182
+
183
+
184
+ _str_to_trt_dtype_dict = dict(float16=trt.float16,
185
+ float32=trt.float32,
186
+ int64=trt.int64,
187
+ int32=trt.int32,
188
+ int8=trt.int8,
189
+ bool=trt.bool,
190
+ bfloat16=trt.bfloat16,
191
+ fp8=trt.fp8)
192
+
193
+
194
+ def str_dtype_to_trt(dtype):
195
+ ret = _str_to_trt_dtype_dict.get(dtype)
196
+ assert ret is not None, f'Unsupported dtype: {dtype}'
197
+ return ret
198
+
199
+
200
+ _trt_to_str_dtype_dict = {v: k for k, v in _str_to_trt_dtype_dict.items()}
201
+
202
+
203
+ def trt_dtype_to_str(dtype: trt.DataType) -> str:
204
+ assert isinstance(dtype, trt.DataType)
205
+ return _trt_to_str_dtype_dict[dtype]
206
+
207
+
208
+ _np_to_trt_dtype_dict = {
209
+ np.int8: trt.int8,
210
+ np.int32: trt.int32,
211
+ np.int64: trt.int64,
212
+ np.float16: trt.float16,
213
+ np.float32: trt.float32,
214
+ np.bool_: trt.bool,
215
+
216
+ # hash of np.dtype('int32') != np.int32
217
+ np.dtype('int8'): trt.int8,
218
+ np.dtype('int32'): trt.int32,
219
+ np.dtype('int64'): trt.int64,
220
+ np.dtype('float16'): trt.float16,
221
+ np.dtype('float32'): trt.float32,
222
+ np.dtype('bool'): trt.bool,
223
+ np_bfloat16: trt.bfloat16,
224
+ np_float8: trt.fp8,
225
+ }
226
+
227
+
228
+ def np_dtype_to_trt(dtype):
229
+ ret = _np_to_trt_dtype_dict.get(dtype)
230
+ assert ret is not None, f'Unsupported dtype: {dtype}'
231
+ return ret
232
+
233
+
234
+ _trt_to_np_dtype_dict = {
235
+ trt.int8: np.int8,
236
+ trt.int32: np.int32,
237
+ trt.int64: np.int64,
238
+ trt.float16: np.float16,
239
+ trt.float32: np.float32,
240
+ trt.bool: np.bool_,
241
+ trt.bfloat16: np_bfloat16,
242
+ trt.fp8: np_float8,
243
+ }
244
+
245
+
246
+ def trt_dtype_to_np(dtype):
247
+ ret = _trt_to_np_dtype_dict.get(dtype)
248
+ assert ret is not None, f'Unsupported dtype: {dtype}'
249
+ return ret
250
+
251
+
252
+ _torch_to_np_dtype_dict = {
253
+ torch.bool: np.bool_,
254
+ torch.uint8: np.uint8,
255
+ torch.int8: np.int8,
256
+ torch.int16: np.int16,
257
+ torch.int32: np.int32,
258
+ torch.int64: np.int64,
259
+ torch.float16: np.float16,
260
+ torch.bfloat16: np_bfloat16,
261
+ torch.float8_e4m3fn: np_float8,
262
+ torch.float32: np.float32,
263
+ torch.float64: np.float64,
264
+ torch.complex64: np.complex64,
265
+ torch.complex128: np.complex128,
266
+ }
267
+
268
+
269
+ def torch_dtype_to_np(dtype):
270
+ ret = _torch_to_np_dtype_dict.get(dtype)
271
+ assert ret is not None, f'Unsupported dtype: {dtype}'
272
+ return ret
273
+
274
+
275
+ _trt_to_torch_dtype_dict = {
276
+ trt.float16: torch.float16,
277
+ trt.float32: torch.float32,
278
+ trt.int64: torch.int64,
279
+ trt.int32: torch.int32,
280
+ trt.int8: torch.int8,
281
+ trt.bool: torch.bool,
282
+ trt.bfloat16: torch.bfloat16,
283
+ trt.fp8: torch.float8_e4m3fn,
284
+ }
285
+
286
+
287
+ def trt_dtype_to_torch(dtype):
288
+ ret = _trt_to_torch_dtype_dict.get(dtype)
289
+ assert ret is not None, f'Unsupported dtype: {dtype}'
290
+ return ret
291
+
292
+
293
+ def is_same_dtype(type_a: Union[str, trt.DataType],
294
+ type_b: Union[str, trt.DataType]) -> bool:
295
+ if isinstance(type_a, str):
296
+ type_a = str_dtype_to_trt(type_a)
297
+
298
+ if isinstance(type_b, str):
299
+ type_b = str_dtype_to_trt(type_b)
300
+
301
+ return type_a == type_b
302
+
303
+
304
+ _torch_to_trt_dtype_dict = {
305
+ torch.float16: trt.float16,
306
+ torch.float32: trt.float32,
307
+ torch.int64: trt.int64,
308
+ torch.int32: trt.int32,
309
+ torch.int8: trt.int8,
310
+ torch.float8_e4m3fn: trt.fp8,
311
+ torch.qint8: trt.int8,
312
+ torch.bool: trt.bool,
313
+ torch.bfloat16: trt.bfloat16
314
+ }
315
+
316
+
317
+ def torch_dtype_to_trt(dtype):
318
+ ret = _torch_to_trt_dtype_dict.get(dtype)
319
+ assert ret is not None, f'Unsupported dtype: {dtype}'
320
+ return ret
321
+
322
+
323
+ def dim_to_trt_axes(dim):
324
+ """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
325
+ if not isinstance(dim, tuple):
326
+ dim = (dim, )
327
+
328
+ # create axes bitmask for reduce layer
329
+ axes = 0
330
+ for d in dim:
331
+ axes |= 1 << d
332
+
333
+ return axes
334
+
335
+
336
+ def trt_axes_to_dim(axes: int) -> List[int]:
337
+ """Converts tensorrt axes bitmask to dims"""
338
+ dim = []
339
+ for i in range(32):
340
+ if axes & (1 << i):
341
+ dim.append(i)
342
+
343
+ return dim
344
+
345
+
346
+ def dim_resolve_negative(dim, ndim):
347
+ if not isinstance(dim, tuple):
348
+ dim = (dim, )
349
+ pos = []
350
+ for d in dim:
351
+ if d < 0:
352
+ d = ndim + d
353
+ pos.append(d)
354
+ return tuple(pos)
355
+
356
+
357
+ # mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
358
+ OMPI_COMM_TYPE_HOST = 9
359
+
360
+
361
+ def mpi_comm():
362
+ from mpi4py import MPI
363
+ return MPI.COMM_WORLD
364
+
365
+
366
+ def mpi_rank():
367
+ return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0
368
+
369
+
370
+ def mpi_world_size():
371
+ return mpi_comm().Get_size() if ENABLE_MULTI_DEVICE else 1
372
+
373
+
374
+ def mpi_barrier():
375
+ mpi_comm().Barrier()
376
+
377
+
378
+ def mpi_broadcast(obj, root=0):
379
+ return mpi_comm().bcast(obj, root)
380
+
381
+
382
+ def pad_vocab_size(vocab_size, tp_size):
383
+ return int(math.ceil(vocab_size / tp_size) * tp_size)
384
+
385
+
386
+ def to_dict(obj):
387
+ return copy.deepcopy(obj.__dict__)
388
+
389
+
390
+ def to_json_string(obj):
391
+ if not isinstance(obj, dict):
392
+ obj = to_dict(obj)
393
+ return json.dumps(obj, indent=2, sort_keys=True) + "\n"
394
+
395
+
396
+ def to_json_file(obj, json_file_path):
397
+ with open(json_file_path, "w", encoding="utf-8") as writer:
398
+ writer.write(to_json_string(obj))
399
+
400
+
401
+ def numpy_fp32_to_bf16(src):
402
+ # Numpy doesn't support bfloat16 type
403
+ # Convert float32 to bfloat16 manually and assign with bf16 abstract type
404
+ original_shape = src.shape
405
+ src = src.flatten()
406
+ src = np.ascontiguousarray(src)
407
+
408
+ assert src.dtype == np.float32
409
+ dst = np.empty_like(src, dtype=np.uint16)
410
+ for i in range(len(dst)):
411
+ bytes = struct.pack('<f', src[i])
412
+ dst[i] = struct.unpack('<H', struct.pack('BB', bytes[2], bytes[3]))[0]
413
+ return dst.reshape(original_shape).view(np_bfloat16)
414
+
415
+
416
+ _extra_attrs_by_object: Dict[int, Dict[str, Any]] = {}
417
+
418
+
419
+ def get_extra_attr(obj, attr_name):
420
+ if id(obj) not in _extra_attrs_by_object:
421
+ return None
422
+ extra_attrs = _extra_attrs_by_object[id(obj)]
423
+ return extra_attrs.get(attr_name)
424
+
425
+
426
+ def _clean_extra_attrs(obj_id):
427
+ if obj_id in _extra_attrs_by_object:
428
+ del _extra_attrs_by_object[obj_id]
429
+
430
+
431
+ def set_extra_attr(obj, attr_name, value):
432
+ if id(obj) not in _extra_attrs_by_object:
433
+ _extra_attrs_by_object[id(obj)] = {}
434
+ weakref.finalize(obj, _clean_extra_attrs, id(obj))
435
+ _extra_attrs_by_object[id(obj)][attr_name] = value
436
+
437
+
438
+ def has_extra_attr(obj, attr_name):
439
+ if id(obj) not in _extra_attrs_by_object:
440
+ return False
441
+ return attr_name in _extra_attrs_by_object[id(obj)]
442
+
443
+
444
+ def set_obj_attrs(
445
+ obj: torch.Tensor,
446
+ ojb_attrs: Optional[Dict[str, Any]],
447
+ ):
448
+ """Set attributes on a object.
449
+
450
+ This method is used to set attributes on an object. This method
451
+ will not overwrite existing attributes.
452
+ """
453
+ if ojb_attrs is None:
454
+ return
455
+ for key, value in ojb_attrs.items():
456
+ assert not hasattr(
457
+ obj, key), (f"Overwriting existing tensor attribute: {key}")
458
+ setattr(obj, key, value)
459
+
460
+
461
+ def get_init_params(obj, cls=None):
462
+ """
463
+ Get all parameters in object's __init__.
464
+ Use cls's __init__ as filter if cls provided.
465
+ """
466
+ names = None
467
+ if cls is not None:
468
+ names = set(list(inspect.signature(cls.__init__).parameters)[1:])
469
+ return {
470
+ name: getattr(obj, name)
471
+ for name in list(inspect.signature(obj.__class__.__init__).parameters)
472
+ [1:] if names is None or name in names
473
+ }
474
+
475
+
476
+ def release_gc():
477
+ ''' Release memory allocated by PyTorch and Python garbage collector explicitly and immediately.
478
+ This could be used when some states might be kept in memory even after the variables are deleted.
479
+ '''
480
+ gc.collect()
481
+ if torch.cuda.is_available():
482
+ torch.cuda.empty_cache()
483
+ torch.cuda.ipc_collect()
484
+
485
+
486
+ class DictConversion:
487
+
488
+ @classmethod
489
+ def from_dict(cls, config: Dict[str, Any]):
490
+ obj = cls()
491
+ fields = obj.__dataclass_fields__
492
+ for key, value in config.items():
493
+ assert hasattr(obj, key)
494
+ field_cls = fields[key].type
495
+ if (isinstance(field_cls, type)
496
+ and issubclass(field_cls, DictConversion)
497
+ and isinstance(value, dict)):
498
+ value = field_cls.from_dict(value)
499
+ setattr(obj, key, value)
500
+ return obj
501
+
502
+ def to_dict(self):
503
+ return asdict(self)
504
+
505
+ @classmethod
506
+ def from_json_file(cls, file):
507
+ with open(file) as f:
508
+ return cls.from_dict(json.load(f))
509
+
510
+ def set_defaults(self, **kwargs):
511
+ for key, default in kwargs.items():
512
+ value = getattr(self, key)
513
+ if (value is None
514
+ or (isinstance(value, (list, dict)) and len(value) == 0)):
515
+ setattr(self, key, default)
516
+
517
+
518
+ class BaseEnumMeta(EnumMeta):
519
+
520
+ def __contains__(cls, item):
521
+ try:
522
+ cls(item)
523
+ except ValueError:
524
+ return False
525
+ return True
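
A self-contained illustration of the float32-to-bfloat16 conversion that numpy_fp32_to_bf16 above performs element by element: the bfloat16 bit pattern is simply the upper 16 bits of the IEEE-754 float32 encoding (plain truncation, no rounding):

import struct

def fp32_to_bf16_bits(x: float) -> int:
    b = struct.pack("<f", x)               # 4 little-endian bytes of the float32
    return struct.unpack("<H", b[2:4])[0]  # keep the two high-order bytes

print(hex(fp32_to_bf16_bits(1.0)))   # 0x3f80, the bfloat16 encoding of 1.0
print(hex(fp32_to_bf16_bits(-2.0)))  # 0xc000, the bfloat16 encoding of -2.0
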
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .auto_parallel import auto_parallel
2
+ from .cluster_info import infer_cluster_config
3
+ from .config import AutoParallelConfig
4
+
5
+ __all__ = [
6
+ 'auto_parallel',
7
+ 'AutoParallelConfig',
8
+ 'infer_cluster_config',
9
+ ]
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/auto_parallel.py ADDED
@@ -0,0 +1,263 @@
1
+ import gc
2
+ import os
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from pathlib import Path
5
+
6
+ import tensorrt as trt
7
+ import torch
8
+ from filelock import FileLock
9
+
10
+ from tensorrt_llm.functional import DimRange, Tensor
11
+ from tensorrt_llm.logger import logger
12
+ from tensorrt_llm.network import Network, net_guard
13
+
14
+ from .config import AutoParallelConfig
15
+ from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
16
+ from .node_graph import NodeGraph
17
+ from .parallelization import ParallelConfig, parallelize
18
+ from .pipeline_graph import PipelineGraph
19
+ from .simplifier import GraphConfig, Simplifier, StageType
20
+ from .utils import current_flags
21
+
22
+
23
+ def to_network(graph: PipelineGraph, network: Network):
24
+ logger.debug("Converting graph to network")
25
+ trt_network = graph.as_trt()
26
+ trt_network.name = network.trt_network.name
27
+ new_network = Network()
28
+ new_network._init(trt_network)
29
+ new_network._dtype = network._dtype
30
+ new_network._plugin_config = network._plugin_config
31
+ new_network._unfilled_weights = graph._unfilled_weights
32
+ new_network._auto_parallel_config = graph._auto_parallel_config
33
+ with net_guard(network):
34
+ for i in range(trt_network.num_inputs):
35
+ input = trt_network.get_input(i)
36
+ tensor = Tensor(is_network_input=False)
37
+ if input.name in network._inputs:
38
+ profiles = network._inputs[input.name].profiles
39
+ elif len(network._inputs) == 0:
40
+ profiles = []
41
+ else:
42
+ shape = input.shape
43
+ num_profiles = len(list(network._inputs.values())[0].profiles)
44
+ profile = DimRange(shape, [None] * len(shape))
45
+ profiles = [profile] * num_profiles
46
+ tensor.profiles = profiles
47
+ tensor.trt_tensor = input
48
+ new_network._inputs[input.name] = tensor
49
+ return new_network
50
+
51
+
52
+ def find_solution(
53
+ node_graph: NodeGraph,
54
+ graph_config: GraphConfig,
55
+ lmesh: LogicalDeviceMesh,
56
+ memory_budget: int,
57
+ flags: list,
58
+ device: int,
59
+ dump_path: str,
60
+ ) -> ParallelConfig:
61
+ torch.cuda.set_device(device)
62
+ with current_flags(*flags):
63
+ cost_graph = node_graph.get_cost_graph(lmesh)
64
+ num_stages = graph_config.num_stages
65
+ if num_stages == 1:
66
+ stage_types = [None]
67
+ elif num_stages == 2:
68
+ stage_types = [StageType.START, StageType.END]
69
+ else:
70
+ stage_types = [StageType.START, StageType.BLOCK, StageType.END]
71
+
72
+ best_config, best_solution = None, None
73
+ for stage_type in stage_types:
74
+ if stage_type is not None:
75
+ node_graph.set_slowest_stage(stage_type, graph_config)
76
+ solution = node_graph.find_solution(
77
+ cost_graph,
78
+ memory_budget,
79
+ )
80
+ cost = solution.total_cost
81
+ if best_config is None or cost < best_config.cost:
82
+ best_config = ParallelConfig()
83
+ best_config.graph_config = graph_config
84
+ best_config.lmesh = lmesh
85
+ best_config.cost = cost
86
+ best_config.graph_strategy = solution.node_best_strategy
87
+ best_config.stage_type = stage_type
88
+ best_solution = solution
89
+ if dump_path is not None:
90
+ lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
91
+ vlz_name = f"{dump_path}/solution."
92
+ if graph_config.num_micro_batches != 1:
93
+ vlz_name += f"mbs{graph_config.num_micro_batches}."
94
+ if graph_config.num_stages != 1:
95
+ vlz_name += f"stages{graph_config.num_stages}."
96
+ vlz_name += lmesh.cluster_key
97
+ with lock:
98
+ node_graph.visualize_solution(
99
+ best_solution,
100
+ vlz_name,
101
+ ignore_shape_io=True,
102
+ )
103
+ return best_config
104
+
105
+
106
+ def infer_builder_flags(network):
107
+ fp16_enabled = False
108
+ bf16_enabled = False
109
+ int8_enabled = False
110
+ fp8_enabled = False
111
+
112
+ def check_dtype(tensor):
113
+ nonlocal fp16_enabled
114
+ nonlocal bf16_enabled
115
+ nonlocal int8_enabled
116
+ nonlocal fp8_enabled
117
+ if tensor.dtype == trt.DataType.HALF:
118
+ fp16_enabled = True
119
+ elif tensor.dtype == trt.DataType.BF16:
120
+ bf16_enabled = True
121
+ elif tensor.dtype == trt.DataType.INT8:
122
+ int8_enabled = True
123
+ elif tensor.dtype == trt.DataType.FP8:
124
+ fp8_enabled = True
125
+
126
+ trt_network = network.trt_network
127
+ for i in range(trt_network.num_inputs):
128
+ input = trt_network.get_input(i)
129
+ check_dtype(input)
130
+ for i in range(trt_network.num_layers):
131
+ layer = trt_network.get_layer(i)
132
+ for j in range(layer.num_outputs):
133
+ output = layer.get_output(j)
134
+ check_dtype(output)
135
+
136
+ builder_flags = 0
137
+ if fp16_enabled:
138
+ builder_flags |= 1 << int(trt.BuilderFlag.FP16)
139
+ builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
140
+ if bf16_enabled:
141
+ builder_flags |= 1 << int(trt.BuilderFlag.BF16)
142
+ builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
143
+ if int8_enabled:
144
+ builder_flags |= 1 << int(trt.BuilderFlag.INT8)
145
+ if fp8_enabled:
146
+ builder_flags |= 1 << int(trt.BuilderFlag.FP8)
147
+ builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
148
+ return builder_flags
149
+
150
+
151
+ def auto_parallel(network: Network, config: AutoParallelConfig):
152
+ debug_mode = config.debug_mode
153
+ memory_budget = config.get_cluster_info(
154
+ ).memory_budget_per_device * 1024 * 1024 * 1024
155
+ enable_pipeline_parallelism = config.enable_pipeline_parallelism
156
+ if config.world_size < config.gpus_per_node:
157
+ num_hosts = 1
158
+ num_devices_per_host = config.world_size
159
+ else:
160
+ assert config.world_size % config.gpus_per_node == 0
161
+ num_hosts = config.world_size // config.gpus_per_node
162
+ num_devices_per_host = config.gpus_per_node
163
+ parallel_config_cache = config.parallel_config_cache
164
+ dump_path = config.dump_path if debug_mode else None
165
+ fill_weights = config.fill_weights
166
+
167
+ if num_hosts == 1 and num_devices_per_host == 1:
168
+ return [network]
169
+
170
+ if dump_path is not None:
171
+ if not os.path.exists(dump_path):
172
+ os.makedirs(dump_path)
173
+
174
+ builder_flags = config.builder_flags or infer_builder_flags(network)
175
+ flags = [builder_flags, network.strongly_typed]
176
+ with current_flags(*flags):
177
+ simplifier = Simplifier(network, config)
178
+ network_hash = simplifier.get_network_hash()
179
+
180
+ best_config = None
181
+ if parallel_config_cache is not None and Path(
182
+ parallel_config_cache).exists():
183
+ parallel_config = ParallelConfig.from_file(parallel_config_cache)
184
+ if (ParallelConfig.VERSION == parallel_config.version
185
+ and network_hash == parallel_config.network_hash
186
+ and config == parallel_config.auto_parallel_config):
187
+ logger.info(
188
+ f"use cache of parallel config from {parallel_config_cache}"
189
+ )
190
+ best_config = parallel_config
191
+
192
+ if best_config is None:
193
+ num_devices = num_hosts * num_devices_per_host
194
+ phy_ids = [[
195
+ i + j * num_devices_per_host
196
+ for i in range(num_devices_per_host)
197
+ ] for j in range(num_hosts)]
198
+ phy_mesh = PhysicalDeviceMesh(phy_ids, config)
199
+ if enable_pipeline_parallelism:
200
+ num_micro_batches_list = simplifier.list_all_num_micro_batches()
201
+ else:
202
+ num_micro_batches_list = [1]
203
+
204
+ jobs = []
205
+ for num_micro_batches in num_micro_batches_list:
206
+ simplifier.infer_shapes(num_micro_batches)
207
+ if enable_pipeline_parallelism:
208
+ pipeline_configs = phy_mesh.list_all_pipeline_configs()
209
+ else:
210
+ pipeline_configs = [(1, num_devices)]
211
+ for num_stages, num_devices_per_stage in pipeline_configs:
212
+ # TODO: add fallback path that allows num_micro_batches >= num_stages
213
+ # if no solution satisfies memory budget
214
+ if num_micro_batches < num_stages:
215
+ continue
216
+ simplified_graph, graph_config = simplifier.simplify_graph(
217
+ phy_mesh,
218
+ num_stages,
219
+ num_devices_per_stage,
220
+ )
221
+ if simplified_graph is None:
222
+ continue
223
+ node_graph = NodeGraph(simplified_graph)
224
+ node_graph.assign_cost_weights(graph_config)
225
+ lmeshes = graph_config.stage_phy_meshes[
226
+ 0].get_logical_meshes()
227
+ for lmesh in lmeshes:
228
+ jobs.append(
229
+ (node_graph, graph_config, lmesh, memory_budget *
230
+ (num_devices / num_devices_per_stage)))
231
+
232
+ try:
233
+ with ThreadPoolExecutor() as executor:
234
+ best_config = sorted(
235
+ executor.map(
236
+ lambda x: find_solution(
237
+ *x,
238
+ flags,
239
+ torch.cuda.current_device(),
240
+ dump_path,
241
+ ),
242
+ jobs,
243
+ ),
244
+ key=lambda x: x.cost,
245
+ )[0]
246
+ finally:
247
+ phy_mesh.close()
248
+
249
+ if parallel_config_cache is not None:
250
+ best_config.network_hash = network_hash
251
+ best_config.auto_parallel_config = config
252
+ best_config.save(parallel_config_cache)
253
+
254
+ new_graphs = parallelize(simplifier, best_config)
255
+
256
+ networks = [to_network(new_graph, network) for new_graph in new_graphs]
257
+ if debug_mode and fill_weights:
258
+ networks[0]._fill_weights()
259
+
260
+ gc.collect()
261
+ torch.cuda.empty_cache()
262
+
263
+ return networks
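
A worked example of the physical device-id grid that auto_parallel() builds before constructing the PhysicalDeviceMesh: one row per host, one column per local device (hypothetical 2 hosts with 4 GPUs each):

num_hosts = 2
num_devices_per_host = 4
phy_ids = [[i + j * num_devices_per_host for i in range(num_devices_per_host)]
           for j in range(num_hosts)]
print(phy_ids)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
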
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/cluster_info.py ADDED
@@ -0,0 +1,556 @@
1
+ import copy
2
+ import re
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, Tuple, Union
5
+
6
+ import pynvml
7
+ import torch
8
+ from cuda import cudart
9
+
10
+ from tensorrt_llm._utils import DictConversion
11
+ from tensorrt_llm.logger import logger
12
+ from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn
13
+
14
+
15
+ @dataclass
16
+ class MathThroughput(DictConversion):
17
+ int4: int = 0 # Tflops
18
+ int8: int = 0 # Tflops
19
+ fp8: int = 0 # Tflops
20
+ float16: int = 0 # Tflops
21
+ bfloat16: int = 0 # Tflops
22
+ float32: int = 0 # Tflops
23
+
24
+ @staticmethod
25
+ def to_tflops(
26
+ ipc_per_sm: "MathThroughput",
27
+ sm_count: int,
28
+ clock_mhz: int,
29
+ ) -> "MathThroughput":
30
+ tflops = MathThroughput()
31
+ for name in ipc_per_sm.__dataclass_fields__:
32
+ setattr(
33
+ tflops, name,
34
+ getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6))
35
+ return tflops
36
+
37
+
38
+ @dataclass
39
+ class ClusterInfo(DictConversion):
40
+ inter_node_bw_per_device: int = 25 # GBps
41
+ intra_node_bw_per_device: int = 0 # GBps
42
+ inter_node_latency: int = 10 # us
43
+ intra_node_latency: int = 10 # us
44
+ intra_node_sharp: bool = False
45
+ inter_node_sharp: bool = True
46
+
47
+ memory_bw: int = 0 # GBps
48
+ memory_budget_per_device: int = 0 # GB
49
+
50
+ math_throughput: MathThroughput = field(default_factory=MathThroughput)
51
+
52
+ memory_efficiency: float = 1.0
53
+ math_efficiency: float = 1.0
54
+ communication_efficiency: float = 1.0
55
+
56
+
57
+ _math_throughputs = {
58
+ "A100": MathThroughput(
59
+ int8=624,
60
+ float16=312,
61
+ bfloat16=312,
62
+ float32=156,
63
+ ),
64
+ }
65
+
66
+ _bandwidths = {
67
+ "PCIe-3": 16,
68
+ "PCIe-4": 32,
69
+ "PCIe-5": 64,
70
+ }
71
+
72
+ cluster_infos = {
73
+ # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
74
+ "A100-SXM-80GB":
75
+ ClusterInfo(
76
+ intra_node_bw_per_device=300,
77
+ memory_bw=2039,
78
+ memory_budget_per_device=80,
79
+ math_throughput=_math_throughputs["A100"],
80
+ ),
81
+ "A100-SXM-40GB":
82
+ ClusterInfo(
83
+ intra_node_bw_per_device=300,
84
+ memory_bw=1555,
85
+ memory_budget_per_device=40,
86
+ math_throughput=_math_throughputs["A100"],
87
+ ),
88
+ "A100-PCIe-80GB":
89
+ ClusterInfo(
90
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
91
+ memory_bw=1935,
92
+ memory_budget_per_device=80,
93
+ math_throughput=_math_throughputs["A100"],
94
+ ),
95
+ "A100-PCIe-40GB":
96
+ ClusterInfo(
97
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
98
+ memory_bw=1555,
99
+ memory_budget_per_device=40,
100
+ math_throughput=_math_throughputs["A100"],
101
+ ),
102
+ # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
103
+ "H100-SXM":
104
+ ClusterInfo(
105
+ inter_node_bw_per_device=50,
106
+ intra_node_bw_per_device=450,
107
+ intra_node_sharp=True,
108
+ memory_bw=3350,
109
+ memory_budget_per_device=80,
110
+ math_throughput=MathThroughput(
111
+ int8=1979,
112
+ fp8=1979,
113
+ float16=989,
114
+ bfloat16=989,
115
+ float32=495,
116
+ ),
117
+ ),
118
+ "H100-PCIe":
119
+ ClusterInfo(
120
+ inter_node_bw_per_device=50,
121
+ intra_node_bw_per_device=_bandwidths["PCIe-5"],
122
+ memory_bw=2000,
123
+ memory_budget_per_device=80,
124
+ math_throughput=MathThroughput(
125
+ int8=1513,
126
+ fp8=1513,
127
+ float16=756,
128
+ bfloat16=756,
129
+ float32=378,
130
+ ),
131
+ ),
132
+ "H20":
133
+ ClusterInfo(
134
+ inter_node_bw_per_device=50,
135
+ intra_node_bw_per_device=450,
136
+ memory_bw=4000,
137
+ memory_budget_per_device=96,
138
+ math_throughput=MathThroughput(
139
+ int8=293,
140
+ fp8=293,
141
+ float16=147,
142
+ bfloat16=147,
143
+ float32=74,
144
+ ),
145
+ ),
146
+ # from https://images.nvidia.cn/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
147
+ "V100-PCIe-16GB":
148
+ ClusterInfo(
149
+ intra_node_bw_per_device=_bandwidths["PCIe-3"],
150
+ memory_bw=900,
151
+ memory_budget_per_device=16,
152
+ math_throughput=MathThroughput(float32=112),
153
+ ),
154
+ "V100-PCIe-32GB":
155
+ ClusterInfo(
156
+ intra_node_bw_per_device=_bandwidths["PCIe-3"],
157
+ memory_bw=900,
158
+ memory_budget_per_device=32,
159
+ math_throughput=MathThroughput(float32=112),
160
+ ),
161
+ "V100-SXM-16GB":
162
+ ClusterInfo(
163
+ intra_node_bw_per_device=150,
164
+ memory_bw=900,
165
+ memory_budget_per_device=16,
166
+ math_throughput=MathThroughput(float32=125),
167
+ ),
168
+ "V100-SXM-32GB":
169
+ ClusterInfo(
170
+ intra_node_bw_per_device=150,
171
+ memory_bw=900,
172
+ memory_budget_per_device=32,
173
+ math_throughput=MathThroughput(float32=125),
174
+ ),
175
+ "V100S-PCIe":
176
+ ClusterInfo(
177
+ intra_node_bw_per_device=_bandwidths["PCIe-3"],
178
+ memory_bw=1134,
179
+ memory_budget_per_device=32,
180
+ math_throughput=MathThroughput(float32=130),
181
+ ),
182
+ # from https://images.nvidia.cn/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
183
+ "A40":
184
+ ClusterInfo(
185
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
186
+ memory_bw=696,
187
+ memory_budget_per_device=48,
188
+ math_throughput=MathThroughput(
189
+ int4=600,
190
+ int8=300,
191
+ float16=150,
192
+ bfloat16=150,
193
+ float32=75,
194
+ ),
195
+ ),
196
+ # from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf
197
+ "A30":
198
+ ClusterInfo(
199
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
200
+ memory_bw=933,
201
+ memory_budget_per_device=24,
202
+ math_throughput=MathThroughput(
203
+ int4=661,
204
+ int8=330,
205
+ float16=165,
206
+ bfloat16=165,
207
+ float32=82,
208
+ ),
209
+ ),
210
+ # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf
211
+ "A10":
212
+ ClusterInfo(
213
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
214
+ memory_bw=600,
215
+ memory_budget_per_device=24,
216
+ math_throughput=MathThroughput(
217
+ int4=500,
218
+ int8=250,
219
+ float16=125,
220
+ bfloat16=125,
221
+ float32=62.5,
222
+ ),
223
+ ),
224
+ "A10G":
225
+ ClusterInfo(
226
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
227
+ memory_bw=600,
228
+ memory_budget_per_device=24,
229
+ math_throughput=MathThroughput(
230
+ int4=280,
231
+ int8=140,
232
+ float16=70,
233
+ bfloat16=70,
234
+ float32=35,
235
+ ),
236
+ ),
237
+ # from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
238
+ "L40S":
239
+ ClusterInfo(
240
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
241
+ memory_bw=864,
242
+ memory_budget_per_device=48,
243
+ math_throughput=MathThroughput(
244
+ int4=733,
245
+ int8=733,
246
+ fp8=733,
247
+ float16=362,
248
+ bfloat16=362,
249
+ float32=183,
250
+ ),
251
+ ),
252
+ # from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf
253
+ "L40":
254
+ ClusterInfo(
255
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
256
+ memory_bw=864,
257
+ memory_budget_per_device=48,
258
+ math_throughput=MathThroughput(
259
+ int4=724,
260
+ int8=362,
261
+ fp8=362,
262
+ float16=181,
263
+ bfloat16=181,
264
+ float32=90,
265
+ ),
266
+ ),
267
+ "L20":
268
+ ClusterInfo(
269
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
270
+ memory_bw=864,
271
+ memory_budget_per_device=48,
272
+ math_throughput=MathThroughput(
273
+ int8=238,
274
+ fp8=238,
275
+ float16=119,
276
+ bfloat16=119,
277
+ float32=60,
278
+ ),
279
+ ),
280
+ # from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652
281
+ "L4":
282
+ ClusterInfo(
283
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
284
+ memory_bw=300,
285
+ memory_budget_per_device=24,
286
+ math_throughput=MathThroughput(
287
+ int8=242,
288
+ fp8=242,
289
+ float16=120,
290
+ bfloat16=120,
291
+ float32=60,
292
+ ),
293
+ ),
294
+ "L2":
295
+ ClusterInfo(
296
+ intra_node_bw_per_device=_bandwidths["PCIe-4"],
297
+ memory_bw=300,
298
+ memory_budget_per_device=24,
299
+ math_throughput=MathThroughput(
300
+ int8=193,
301
+ fp8=193,
302
+ float16=97,
303
+ bfloat16=97,
304
+ float32=48,
305
+ ),
306
+ ),
307
+ }
308
+
309
+
310
+ def infer_cluster_key() -> str:
311
+
312
+ def match(product, name):
313
+ # Using A100 as an example, the regex pattern matches:
314
+ # - NVIDIA A100 80GB
315
+ # - NVIDIA A100-PCIE
316
+ # - NVIDIA A100
317
+ # And does not match A1000 etc.
318
+ return re.match(f".*{product}([ -]|$).*", name) is not None
319
+
320
+ def is_sxm():
321
+ return "SXM" in device_name
322
+
323
+ def is_80gb():
324
+ return "80GB" in device_name
325
+
326
+ def is_32gb():
327
+ return "32GB" in device_name
328
+
329
+ device_name = torch.cuda.get_device_name(torch.cuda.current_device())
330
+
331
+ if match("A100", device_name):
332
+ if is_sxm():
333
+ if is_80gb():
334
+ return "A100-SXM-80GB"
335
+ else:
336
+ return "A100-SXM-40GB"
337
+ else:
338
+ if is_80gb():
339
+ return "A100-PCIe-80GB"
340
+ else:
341
+ return "A100-PCIe-40GB"
342
+ elif match("A10G", device_name):
343
+ return "A10G"
344
+ elif match("A10", device_name):
345
+ return "A10"
346
+ elif match("A30", device_name):
347
+ return "A30"
348
+ elif match("A40", device_name):
349
+ return "A40"
350
+ elif match("H100", device_name):
351
+ if is_sxm():
352
+ return "H100-SXM"
353
+ else:
354
+ return "H100-PCIe"
355
+ elif match("L40S", device_name):
356
+ return "L40S"
357
+ elif match("L40", device_name):
358
+ return "L40"
359
+ elif match("L4", device_name):
360
+ return "L4"
361
+ elif match("V100S", device_name):
362
+ return "V100S-PCIe"
363
+ elif match("V100", device_name):
364
+ if is_sxm():
365
+ if is_32gb():
366
+ return "V100-SXM-32GB"
367
+ else:
368
+ return "V100-SXM-16GB"
369
+ else:
370
+ if is_32gb():
371
+ return "V100-PCIe-32GB"
372
+ else:
373
+ return "V100-PCIe-16GB"
374
+ return None
375
+
376
+
377
+ def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
378
+ ipc_table = {
379
+ (9, 0):
380
+ MathThroughput(
381
+ int8=16384,
382
+ fp8=16384,
383
+ float16=8192,
384
+ bfloat16=8192,
385
+ float32=4096,
386
+ ),
387
+ (8, 0):
388
+ MathThroughput(
389
+ int4=8192,
390
+ int8=4096,
391
+ float16=2048,
392
+ bfloat16=2048,
393
+ float32=1024,
394
+ ),
395
+ (8, 6):
396
+ MathThroughput(
397
+ int4=4096,
398
+ int8=2048,
399
+ float16=1024,
400
+ bfloat16=1024,
401
+ float32=512,
402
+ ),
403
+ (8, 9):
404
+ MathThroughput(
405
+ int4=2048,
406
+ int8=1024,
407
+ fp8=1024,
408
+ float16=512,
409
+ bfloat16=512,
410
+ float32=256,
411
+ ),
412
+ (7, 0):
413
+ MathThroughput(
414
+ float16=1024,
415
+ float32=128,
416
+ ),
417
+ (7, 5):
418
+ MathThroughput(
419
+ int4=4096,
420
+ int8=2048,
421
+ float16=1024,
422
+ float32=128,
423
+ ),
424
+ }
425
+ return ipc_table.get(compute_cap, MathThroughput())
426
+
427
+
428
+ def nvlink_version(version_enum: int) -> int:
429
+ nvl_version_table = {
430
+ 1: 1,
431
+ 2: 2,
432
+ 3: 2,
433
+ 4: 2,
434
+ 5: 3,
435
+ 6: 3,
436
+ 7: 4,
437
+ }
438
+ return nvl_version_table[version_enum]
439
+
440
+
441
+ def nvlink_bandwidth(nvlink_version: int) -> int:
442
+ nvl_bw_table = {
443
+ 1: 80,
444
+ 2: 150,
445
+ 3: 300,
446
+ 4: 450,
447
+ }
448
+ return nvl_bw_table[nvlink_version]
449
+
450
+
451
+ def infer_cluster_info() -> ClusterInfo:
452
+ device = torch.cuda.current_device()
453
+ index = device.index if isinstance(device, torch.device) else device
454
+ with PyNVMLContext():
455
+ handle = pynvml.nvmlDeviceGetHandleByIndex(index)
456
+ compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
457
+ logger.info(f"Compute capability: {compute_cap}")
458
+ err, properties = cudart.cudaGetDeviceProperties(index)
459
+ sm_count = properties.multiProcessorCount
460
+ logger.info(f"SM count: {sm_count}")
461
+ sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
462
+ handle,
463
+ pynvml.NVML_CLOCK_SM,
464
+ )
465
+ logger.info(f"SM clock: {sm_clock} MHz")
466
+ math_throughput = MathThroughput.to_tflops(
467
+ ipc_per_sm(compute_cap),
468
+ sm_count,
469
+ sm_clock,
470
+ )
471
+ for name in math_throughput.__dataclass_fields__:
472
+ tflops = getattr(math_throughput, name)
473
+ logger.info(f"{name} TFLOPS: {tflops}")
474
+
475
+ mem_info = _device_get_memory_info_fn(handle)
476
+ memory_budget = mem_info.total // (1024**3)
477
+ logger.info(f"Total Memory: {memory_budget} GiB")
478
+
479
+ mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
480
+ handle,
481
+ pynvml.NVML_CLOCK_MEM,
482
+ )
483
+ logger.info(f"Memory clock: {mem_clock} MHz")
484
+ if pynvml.__version__ < '11.5.0':
485
+ mem_bus_width = properties.memoryBusWidth
486
+ else:
487
+ mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
488
+ logger.info(f"Memory bus width: {mem_bus_width}")
489
+ memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
490
+ logger.info(f"Memory bandwidth: {memory_bw} GB/s")
491
+
492
+ try:
493
+ is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
494
+ logger.info(f"NVLink is active: {is_nvl_active}")
495
+ except pynvml.NVMLError:
496
+ is_nvl_active = False
497
+
498
+ intra_node_sharp = False
499
+ if is_nvl_active:
500
+ nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
501
+ nvl_version = nvlink_version(nvl_version_enum)
502
+ logger.info(f"NVLink version: {nvl_version}")
503
+ nvl_bw = nvlink_bandwidth(nvl_version)
504
+ logger.info(f"NVLink bandwidth: {nvl_bw} GB/s")
505
+ intra_node_bw = nvl_bw
506
+ if nvl_version >= 4:
507
+ intra_node_sharp = True
508
+ else:
509
+ if pynvml.__version__ < '11.5.0':
510
+ pcie_gen = pynvml.nvmlDeviceGetCurrPcieLinkGeneration(handle)
511
+ pcie_speed = (2**pcie_gen) * 1000
512
+ else:
513
+ pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
514
+ logger.info(f"PCIe speed: {pcie_speed} Mbps")
515
+ pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
516
+ logger.info(f"PCIe link width: {pcie_link_width}")
517
+ pcie_bw = pcie_speed * pcie_link_width // int(8e3)
518
+ logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
519
+ intra_node_bw = pcie_bw
520
+
521
+ cluster_info = ClusterInfo(
522
+ math_throughput=math_throughput,
523
+ memory_bw=memory_bw,
524
+ memory_budget_per_device=memory_budget,
525
+ intra_node_bw_per_device=intra_node_bw,
526
+ intra_node_sharp=intra_node_sharp,
527
+ )
528
+ return cluster_info
529
+
530
+
531
+ def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
532
+ device_name = torch.cuda.get_device_name(torch.cuda.current_device())
533
+ cluster_key = infer_cluster_key()
534
+ if cluster_key is not None:
535
+ return dict(cluster_key=cluster_key)
536
+ else:
537
+ try:
538
+ cluster_info = infer_cluster_info()
539
+ except pynvml.NVMLError:
540
+ fallback_cluster_key = "L40"
541
+ cluster_info = copy.copy(cluster_infos[fallback_cluster_key])
542
+ memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
543
+ cluster_info.memory_budget_per_device = memory_budget
544
+ logger.warning(
545
+ f"Failed to infer cluster info for {device_name}, "
546
+ f"treating it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
547
+ "This setting has no effect if you do not use auto parallel.")
548
+ return dict(
549
+ cluster_key=device_name.replace(" ", "-"),
550
+ cluster_info=cluster_info,
551
+ )
552
+
553
+
554
+ if __name__ == "__main__":
555
+ logger.set_level("info")
556
+ infer_cluster_info()
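A minimal usage sketch of the helpers above, assuming this module is importable as tensorrt_llm.auto_parallel.cluster_info and a CUDA device is visible. It resolves the predefined ClusterInfo for a recognized GPU and falls back to NVML-based inference for unknown devices; the import path is an assumption based on the file layout.

from tensorrt_llm.auto_parallel.cluster_info import (cluster_infos,
                                                     infer_cluster_config,
                                                     infer_cluster_key)

key = infer_cluster_key()            # e.g. "A100-SXM-80GB", or None for unrecognized GPUs
if key is not None:
    info = cluster_infos[key]        # predefined datasheet numbers
else:
    # probes NVML/cudart; falls back to an L40-like profile if probing fails
    info = infer_cluster_config()["cluster_info"]
print(key, info)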
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/config.py ADDED
@@ -0,0 +1,61 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import auto
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from strenum import LowercaseStrEnum
6
+
7
+ from tensorrt_llm._utils import BaseEnumMeta, DictConversion
8
+
9
+ from .cluster_info import ClusterInfo, cluster_infos
10
+
11
+
12
+ class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta):
13
+ ALPHA_BETA = auto()
14
+ PROFILE = auto()
15
+ S_CURVE = auto()
16
+ # Zero cost model is for test purpose.
17
+ # Use zero cost model for communication will make solver prefer sharding
18
+ # Use zero cost model for computation will make solver prefer replication
19
+ ZERO = auto()
20
+
21
+
22
+ @dataclass
23
+ class AutoParallelConfig(DictConversion):
24
+ # cluster configuration
25
+ world_size: int = 1
26
+ gpus_per_node: int = 8
27
+ cluster_key: str = None
28
+ cluster_info: Optional[ClusterInfo] = None
29
+
30
+ # cost model configuration
31
+ sharding_cost_model: str = CostModel.ALPHA_BETA
32
+ comm_cost_model: str = CostModel.ALPHA_BETA
33
+
34
+ # strategy configuration
35
+ enable_pipeline_parallelism: bool = False
36
+ enable_shard_unbalanced_shape: bool = False
37
+ enable_shard_dynamic_shape: bool = False
38
+ enable_reduce_scatter: bool = True
39
+
40
+ # parallelization configuration
41
+ builder_flags: Optional[int] = None
42
+ debug_mode: bool = False
43
+ infer_shape: bool = True
44
+ validation_mode: bool = False
45
+ same_buffer_io: Dict[str, str] = field(default_factory=dict)
46
+ same_spec_io: Dict[str, str] = field(default_factory=dict)
47
+ sharded_io_allowlist: List[str] = field(default_factory=list)
48
+ fill_weights: bool = False
49
+
50
+ # debug configuration
51
+ parallel_config_cache: Optional[str] = None
52
+ profile_cache: Optional[str] = None
53
+ dump_path: Optional[str] = None
54
+ debug_outputs: Union[List[str], str] = field(default_factory=list)
55
+
56
+ def get_cluster_info(self) -> ClusterInfo:
57
+ return self.cluster_info or cluster_infos[self.cluster_key]
58
+
59
+ @property
60
+ def enabled(self) -> bool:
61
+ return self.world_size > 1
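A short sketch of configuring auto parallel for a single 8-GPU A100 node, using only fields declared in the dataclass above; the import path mirrors the file location and is an assumption.

from tensorrt_llm.auto_parallel.config import AutoParallelConfig, CostModel

config = AutoParallelConfig(
    world_size=8,                        # > 1, so `enabled` is True
    gpus_per_node=8,
    cluster_key="A100-SXM-80GB",         # get_cluster_info() resolves it via cluster_infos
    sharding_cost_model=CostModel.ALPHA_BETA,
    comm_cost_model=CostModel.ALPHA_BETA,
)
assert config.enabled
print(config.get_cluster_info().memory_budget_per_device)  # 80 (GB)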
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/device_mesh.py ADDED
@@ -0,0 +1,612 @@
1
+ import os
2
+ import re
3
+ from abc import ABC, abstractmethod
4
+ from typing import List
5
+
6
+ import h5py
7
+ import numpy as np
8
+ from filelock import FileLock
9
+
10
+ from .config import AutoParallelConfig, CostModel
11
+ from .tensor_parallel.shape_consistency import ShapeConsistencyManager
12
+
13
+
14
+ class ProfileDB(ABC):
15
+ """A database that stores profiling results for multiple device mesh
16
+ shapes."""
17
+
18
+ @abstractmethod
19
+ def query(self, cluster_key, data_key):
20
+ ...
21
+
22
+ @abstractmethod
23
+ def update(self, cluster_key, data_key, mesh_result):
24
+ ...
25
+
26
+ def close(self):
27
+ pass
28
+
29
+
30
+ class MemDB(ProfileDB):
31
+
32
+ def __init__(self):
33
+ self.data = {}
34
+
35
+ def query(self, cluster_key, data_key):
36
+ key = (cluster_key, data_key)
37
+ mesh_result = self.data.get(key, None)
38
+ if mesh_result is None:
39
+ return None
40
+ else:
41
+ return mesh_result[0]
42
+
43
+ def update(self, cluster_key, data_key, mesh_result):
44
+ key = (cluster_key, data_key)
45
+ self.data[key] = mesh_result
46
+
47
+
48
+ class Hdf5DB(ProfileDB):
49
+
50
+ def __init__(self, name):
51
+ self.name = name
52
+ lock_name = self.name + ".lock"
53
+ self.lock = FileLock(lock_name, thread_local=False)
54
+
55
+ def query(self, cluster_key, data_key):
56
+ file_name = f"{self.name}.hdf5"
57
+ key = str((cluster_key, data_key))
58
+ self.lock.acquire()
59
+ mesh_result = None
60
+ with h5py.File(file_name, 'a') as f:
61
+ if key in f:
62
+ self.lock.release()
63
+ mesh_result = f[key]
64
+ return mesh_result[0]
65
+ else:
66
+ return None
67
+
68
+ def update(self, cluster_key, data_key, mesh_result):
69
+ key = str((cluster_key, data_key))
70
+ file_name = f"{self.name}.hdf5"
71
+ with h5py.File(file_name, 'a') as f:
72
+ f[key] = mesh_result
73
+
74
+ def close(self):
75
+ self.lock.release(force=True)
76
+
77
+
78
+ class LogicalDeviceMesh(object):
79
+
80
+ def __init__(self,
81
+ phy_mesh_shape,
82
+ mesh_shape,
83
+ phy_ids,
84
+ config: AutoParallelConfig,
85
+ alpha,
86
+ beta,
87
+ sharp,
88
+ prof_database=None,
89
+ shape_consistency_manager=None,
90
+ host_ips=None):
91
+ self.phy_mesh_shape = phy_mesh_shape
92
+ self.mesh_shape = mesh_shape
93
+ self.phy_ids = phy_ids
94
+ self.host_ips = host_ips
95
+ self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
96
+ [str(i) for i in mesh_shape]))
97
+ self.prof_min_max_size = [1, 2**34]
98
+ self.prof_comm_dtypes = [
99
+ "int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
100
+ "float32", "float64", "bfloat16"
101
+ ]
102
+ self.devices_group = {
103
+ (0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
104
+ (1, ): [self.phy_ids, self.mesh_shape[1]],
105
+ (0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
106
+ }
107
+ self.prof_database = prof_database
108
+ self.shape_consistency_manager = shape_consistency_manager
109
+ self.config = config
110
+ self.cluster_info = config.get_cluster_info()
111
+ self.hw_alpha = alpha
112
+ self.hw_beta = beta
113
+ self.hw_sharp = sharp
114
+ self.algo_alpha_beta = self._estimate_algo_alpha_beta()
115
+ self.comm_op_to_nccl_test_func_name = {
116
+ 'all_reduce': 'all_reduce_perf_mpi',
117
+ 'all_gather': 'all_gather_perf_mpi',
118
+ 'all_to_all': 'alltoall_perf_mpi',
119
+ 'reduce_scatter': 'reduce_scatter_perf_mpi',
120
+ 'split': 'split',
121
+ }
122
+
123
+ @property
124
+ def size(self) -> int:
125
+ return self.phy_ids.size
126
+
127
+ def _estimate_algo_alpha_beta(self):
128
+ ret = {}
129
+ ar_alpha, ar_beta = {}, {}
130
+ ag_alpha, ag_beta = {}, {}
131
+ rs_alpha, rs_beta = {}, {}
132
+ a2a_alpha, a2a_beta = {}, {}
133
+ phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
134
+ if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
135
+ for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
136
+ num_devices = 1
137
+ for dim in dims:
138
+ num_devices = self.mesh_shape[dim] * num_devices
139
+ if num_devices != 1:
140
+ ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
141
+ 0] else self.hw_alpha[0] * num_devices / 2 / (
142
+ num_devices - 1)
143
+ ar_beta[dims] = self.hw_beta[0]
144
+ ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
145
+ num_devices - 1)
146
+ ag_beta[dims] = self.hw_beta[0]
147
+ rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
148
+ num_devices - 1)
149
+ rs_beta[dims] = self.hw_beta[0]
150
+ a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
151
+ num_devices - 1)
152
+ a2a_beta[dims] = self.hw_beta[0]
153
+ # phy and logical have the same mesh shape if num_hosts > 1 and num_devices_per_host > 1
154
+ else:
155
+ for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
156
+ num_devices = 1
157
+ for dim in dims:
158
+ num_devices = self.mesh_shape[dim] * num_devices
159
+ if num_devices != 1:
160
+ if len(dims) == 1:
161
+ dim = dims[0]
162
+ ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
163
+ dim] else self.hw_alpha[dim] * num_devices / 2 / (
164
+ num_devices - 1)
165
+ ar_beta[dims] = self.hw_beta[dim]
166
+ ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
167
+ num_devices - 1)
168
+ ag_beta[dims] = self.hw_beta[dim]
169
+ rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
170
+ num_devices - 1)
171
+ rs_beta[dims] = self.hw_beta[dim]
172
+ a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
173
+ num_devices - 1)
174
+ a2a_beta[dims] = self.hw_beta[dim]
175
+ elif len(dims) == 2: # two level communication
176
+ num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
177
+ inter_node_col_alpha = self.hw_alpha[
178
+ 0] * num_devices_per_host
179
+ inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
180
+ 0] else inter_node_col_alpha * num_hosts / 2 / (
181
+ num_hosts - 1)
182
+ intra_node_ar_alpha = self.hw_alpha[1]
183
+ intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
184
+ 1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
185
+ num_devices_per_host - 1)
186
+ ar_alpha[dims] = min(inter_node_ar_alpha,
187
+ intra_node_ar_alpha)
188
+ ar_beta[dims] = max(self.hw_beta)
189
+ ag_alpha[dims] = min(
190
+ inter_node_col_alpha * num_hosts / (num_hosts - 1),
191
+ self.hw_alpha[1] * num_devices_per_host /
192
+ (num_devices_per_host - 1))
193
+ ag_beta[dims] = max(self.hw_beta)
194
+ rs_alpha[dims] = ag_alpha[dims]
195
+ rs_beta[dims] = ag_beta[dims]
196
+ a2a_alpha[dims] = min(
197
+ num_hosts * self.hw_alpha[0] / (num_hosts - 1),
198
+ self.hw_alpha[1] * num_hosts)
199
+ a2a_beta[dims] = max(self.hw_beta)
200
+ else:
201
+ pass
202
+ ret['all_to_all'] = [a2a_alpha, a2a_beta]
203
+ ret['all_reduce'] = [ar_alpha, ar_beta]
204
+ ret['all_gather'] = [ag_alpha, ag_beta]
205
+ ret['reduce_scatter'] = [rs_alpha, rs_beta]
206
+ ret['p2p_cross_device'] = [
207
+ self.cluster_info.intra_node_bw_per_device,
208
+ self.cluster_info.intra_node_latency
209
+ ]
210
+ ret['p2p_cross_host'] = [
211
+ self.cluster_info.inter_node_bw_per_device,
212
+ self.cluster_info.inter_node_latency
213
+ ]
214
+ return ret
215
+
216
+ #[ToDo][KDuan] stub functions here
217
+ def _profile_split(self, min_max_comm_size):
218
+ comm_size, elapsed_time = [], []
219
+ size = min_max_comm_size[0]
220
+ while size <= min_max_comm_size[1]:
221
+ time = size * 2 / self.cluster_info.memory_bw
222
+ comm_size.append(size)
223
+ elapsed_time.append(time)
224
+ size = size * 2
225
+ return np.array([comm_size, elapsed_time])
226
+
227
+ def _parse_nccl_test_results(self, f_nccl_test_out_log):
228
+ '''[ToDo][KDuan] There are some dtypes that may not be supported by nccl-tests; use the default dtype (float) for those.'''
229
+ start_parse = False
230
+ comm_size, elapsed_time = [], []
231
+ try:
232
+ with open(f_nccl_test_out_log, 'r') as lines:
233
+ for line in lines:
234
+ if start_parse:
235
+ prof_data = re.split(r"[ ]+", line.strip())
236
+ if len(prof_data) != 13:
237
+ continue
238
+ comm_size.append(float(prof_data[0]))
239
+ elapsed_time.append(float(prof_data[5]))
240
+ if 'GB/s' in line and 'us' in line:
241
+ start_parse = True
242
+ except Exception:
243
+ print(f'failed to parse {f_nccl_test_out_log}')
244
+ return comm_size, elapsed_time
245
+
246
+ def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
247
+ func_name, step, workload_key):
248
+
249
+ if func_name == 'split':
250
+ if 2 == step:
251
+ return self._profile_split(min_max_comm_size)
252
+ else:
253
+ return None
254
+ workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
255
+ os.makedirs(workspace_dir, exist_ok=True)
256
+ outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
257
+ if 1 == step:
258
+ num_nodes = len(self.host_ips)
259
+ num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
260
+ ntasks_per_node = num_gpus // num_nodes
261
+ nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
262
+ device_group[1], func_name, min_max_comm_size[0],
263
+ min_max_comm_size[1], dtype, 2)
264
+ sbatch_command = '#!/bin/bash\n'
265
+ sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
266
+ sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
267
+ sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
268
+ sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
269
+ sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
270
+ sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
271
+ ntasks_per_node)
272
+ sbatch_command += '#SBATCH --exclusive\n'
273
+ sbatch_command += '#SBATCH --mem=0\n'
274
+ sbatch_command += '#SBATCH --network=sharp\n'
275
+ sbatch_command += '#SBATCH --mail-type=FAIL\n'
276
+ srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
277
+ num_nodes, ntasks_per_node, outfile, errfile,
278
+ self.config['container'])
279
+ command = sbatch_command + srun_command + nccl_test_command
280
+ with open(workspace_dir + '/workload.sub', 'w') as f:
281
+ f.write(command)
282
+ with open('./preprofiling_step1.sh', 'a') as f:
283
+ f.write(f'sbatch {workspace_dir}/workload.sub\n')
284
+ return None
285
+
286
+ else:
287
+ comm_size, elapsed_time = self._parse_nccl_test_results(outfile)
288
+ if len(comm_size) < 2:
289
+ assert 0, 'the profiling for {} failed at step 1, please try again'.format(
290
+ workload_key)
291
+ else:
292
+ print(workload_key, comm_size, elapsed_time)
293
+ return np.array([comm_size, elapsed_time])
294
+
295
+ def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
296
+ results = {}
297
+ func_name = self.comm_op_to_nccl_test_func_name[comm_op]
298
+ for dtype in self.prof_comm_dtypes:
299
+ size_time = self._profile_with_nccl_test(
300
+ self.prof_min_max_size, dtype, device_group, func_name, step,
301
+ data_key + f'_dtype{dtype}')
302
+ results[dtype] = size_time
303
+ return results
304
+
305
+ def profile_all_comms_perf(self, step):
306
+ if self.mesh_shape == (1, 1):
307
+ return None
308
+ mesh_results = self.prof_database.query(self.cluster_key,
309
+ self.mesh_shape)
310
+ if mesh_results:
311
+ return mesh_results
312
+
313
+ mesh_results = {}
314
+ data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
315
+ for comm_op in [
316
+ 'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
317
+ 'split'
318
+ ]:
319
+ comm_perf = {}
320
+ for dim, device_group in self.devices_group.items():
321
+ # don't need to profile for mesh dim == 1
322
+ if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
323
+ continue
324
+
325
+ comm_perf[dim] = self._profile_single_comm_perf(
326
+ device_group, comm_op, step, data_key +
327
+ '_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
328
+ mesh_results[comm_op] = comm_perf
329
+ if 2 == step:
330
+ self.prof_database.update(self.cluster_key, self.mesh_shape,
331
+ mesh_results)
332
+
333
+ return mesh_results
334
+
335
+ def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
336
+ assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
337
+ 'the comm_size: {} is not in the profile range: [{}, {}]'\
338
+ .format(realsize, size_time_array[0][0], size_time_array[0][-1])
339
+ return np.interp(realsize, size_time_array[0], size_time_array[1])
340
+
341
+ def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
342
+ elapsed_time = 0.0
343
+ if 'split' == comm_op:
344
+ elapsed_time = size_in_bytes * 2 / (
345
+ self.cluster_info.memory_bw *
346
+ self.cluster_info.memory_efficiency) * 1e-3
347
+ else:
348
+ dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
349
+ alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
350
+ elapsed_time = (size_in_bytes /
351
+ (alpha * self.cluster_info.communication_efficiency)
352
+ * 1e-3) + beta
353
+ return elapsed_time
354
+
355
+ def _input_size_to_comm_size(self, comm_op, dims, input_size):
356
+ ret = input_size
357
+ if 'all_gather' == comm_op:
358
+ for dim in dims:
359
+ ret = ret * self.mesh_shape[dim]
360
+ return ret
361
+
362
+ def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
363
+
364
+ size = self._input_size_to_comm_size(comm_op, dim, input_size)
365
+ if self.config.comm_cost_model == CostModel.S_CURVE:
366
+ mesh_perf = self.prof_database.query(self.cluster_key,
367
+ self.mesh_shape)
368
+ assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
369
+ self.mesh_shape)
370
+ comm_op_perf = mesh_perf.get(comm_op, None)
371
+ assert comm_op_perf is not None, '{} is not profiled'.format(
372
+ comm_op)
373
+ elapsed_time = self._model_comm_cost_from_s_curve(
374
+ comm_op_perf[tuple(dim)][dtype], size)
375
+ return elapsed_time
376
+ elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
377
+ elapsed_time = self._model_comm_cost_from_alpha_beta(
378
+ comm_op, tuple(dim), size)
379
+ elif self.config.comm_cost_model == CostModel.PROFILE:
380
+ assert False, 'Profile-based communication cost model is not supported yet'
381
+ elif self.config.comm_cost_model == CostModel.ZERO:
382
+ elapsed_time = 0.0
383
+
384
+ return elapsed_time # us
385
+
386
+
387
+ class PhysicalDeviceMesh(object):
388
+
389
+ def __init__(self,
390
+ phy_devices_id,
391
+ config: AutoParallelConfig,
392
+ prof_database=None,
393
+ shape_consistency_manager=None,
394
+ host_ips=None):
395
+ self.phy_devices_id = np.array(phy_devices_id)
396
+ self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
397
+ self.host_ips = host_ips
398
+ if host_ips is None:
399
+ self.host_ips = [''] * self.num_hosts
400
+ self.config = config
401
+ self.cluster_info = config.get_cluster_info()
402
+ self.prof_database: ProfileDB = prof_database
403
+ self.shape_consistency_manager = shape_consistency_manager
404
+ if self.config.comm_cost_model not in CostModel:
405
+ raise ValueError(
406
+ f'unsupported communication cost model: {self.config.comm_cost_model}'
407
+ )
408
+ if self.config.sharding_cost_model not in CostModel:
409
+ raise ValueError(
410
+ f'unsupported sharding cost model: {self.config.sharding_cost_model}'
411
+ )
412
+ if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
413
+ if self.prof_database is None:
414
+ profile_cache = config.profile_cache
415
+ if profile_cache is None:
416
+ self.prof_database = MemDB()
417
+ else:
418
+ self.prof_database = Hdf5DB(profile_cache)
419
+ elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
420
+ assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
421
+ assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
422
+ if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
423
+ assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'
424
+
425
+ if not shape_consistency_manager:
426
+ self.shape_consistency_manager = ShapeConsistencyManager()
427
+
428
+ @property
429
+ def size(self) -> int:
430
+ return self.phy_devices_id.size
431
+
432
+ def close(self):
433
+ if self.prof_database is not None:
434
+ self.prof_database.close()
435
+
436
+ def split_pipeline_meshes(
437
+ self, num_stages,
438
+ num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
439
+ sub_meshes = []
440
+ if num_devices_per_stage <= self.num_devices_per_host:
441
+ assert self.num_devices_per_host % num_devices_per_stage == 0, \
442
+ "num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
443
+ .format(self.num_devices_per_host, num_devices_per_stage)
444
+ num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
445
+ num_clusters = self.num_hosts * num_clusters_per_host
446
+ assert num_stages % num_clusters == 0, \
447
+ "num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
448
+ for mesh_id in range(num_stages):
449
+ cluster_id = mesh_id % num_clusters
450
+ cluster_col = cluster_id % num_clusters_per_host
451
+ cluster_row = cluster_id // num_clusters_per_host
452
+ sub_devices_id = [
453
+ self.phy_devices_id[cluster_row][cluster_col *
454
+ num_devices_per_stage:(
455
+ (cluster_col + 1) *
456
+ num_devices_per_stage)]
457
+ ]
458
+ sub_meshes.append(
459
+ PhysicalDeviceMesh(sub_devices_id, self.config,
460
+ self.prof_database,
461
+ self.shape_consistency_manager,
462
+ [self.host_ips[cluster_row]]))
463
+ else:
464
+ assert num_devices_per_stage % self.num_devices_per_host == 0, \
465
+ "num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
466
+ .format(num_devices_per_stage, self.num_devices_per_host)
467
+ num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
468
+ assert self.num_hosts % num_host_per_cluster == 0, \
469
+ "num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
470
+ num_clusters = self.num_hosts // num_host_per_cluster
471
+ for mesh_id in range(num_stages):
472
+ cluster_id = mesh_id % num_clusters
473
+ cluster_row = cluster_id * num_host_per_cluster
474
+ sub_devices_id = self.phy_devices_id[cluster_row:(
475
+ cluster_row + num_host_per_cluster)]
476
+ host_ips = self.host_ips[cluster_row:(cluster_row +
477
+ num_host_per_cluster)]
478
+ sub_meshes.append(
479
+ PhysicalDeviceMesh(sub_devices_id, self.config,
480
+ self.prof_database,
481
+ self.shape_consistency_manager,
482
+ host_ips))
483
+ return sub_meshes
484
+
485
+ def _profile_logical_meshes(self, logical_meshes, step):
486
+ for lmesh in logical_meshes:
487
+ lmesh.profile_all_comms_perf(step)
488
+
489
+ def as_logical_mesh(self) -> LogicalDeviceMesh:
490
+ alpha = [
491
+ self.cluster_info.inter_node_bw_per_device,
492
+ self.cluster_info.intra_node_bw_per_device
493
+ ]
494
+ beta = [
495
+ self.cluster_info.inter_node_latency,
496
+ self.cluster_info.intra_node_latency
497
+ ]
498
+ sharp = [
499
+ self.cluster_info.inter_node_sharp,
500
+ self.cluster_info.intra_node_sharp
501
+ ]
502
+ return LogicalDeviceMesh(
503
+ self.phy_devices_id.shape,
504
+ self.phy_devices_id.shape,
505
+ self.phy_devices_id,
506
+ self.config,
507
+ alpha,
508
+ beta,
509
+ sharp,
510
+ self.prof_database,
511
+ self.shape_consistency_manager,
512
+ self.host_ips,
513
+ )
514
+
515
+ def get_logical_meshes(self):
516
+ logical_meshes = []
517
+ # (1, 2) -> (1, 2)
518
+ # (1, 4) -> (2, 2)
519
+ # (1, 8) -> (2, 4)
520
+ # (1, 16) -> (2, 8), (4, 4)
521
+ # (1, 32) -> (2, 16), (4, 8)
522
+ # (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8)
523
+ # (1, 64) -> (2, 32), (4, 16), (8, 8)
524
+ # we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2)
525
+ # we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1)
526
+ if self.num_hosts == 1:
527
+ alpha = [self.cluster_info.intra_node_bw_per_device]
528
+ beta = [self.cluster_info.intra_node_latency]
529
+ sharp = [self.cluster_info.intra_node_sharp]
530
+ for i in range(2, self.num_devices_per_host):
531
+ if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host:
532
+ lmesh_shape = (i, self.num_devices_per_host // i)
533
+ lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
534
+ logical_meshes.append(
535
+ LogicalDeviceMesh(self.phy_devices_id.shape,
536
+ lmesh_shape, lmesh_phy_ids,
537
+ self.config, alpha, beta, sharp,
538
+ self.prof_database,
539
+ self.shape_consistency_manager,
540
+ self.host_ips))
541
+ # (8, 1) -> (2, 4)
542
+ # (16, 1) -> (2, 8), (4, 4)
543
+ elif self.num_devices_per_host == 1:
544
+ alpha = [self.cluster_info.inter_node_bw_per_device]
545
+ beta = [self.cluster_info.inter_node_latency]
546
+ sharp = [self.cluster_info.inter_node_sharp]
547
+ for i in range(2, self.num_hosts):
548
+ if self.num_hosts % i == 0 and i * i <= self.num_hosts:
549
+ lmesh_shape = (i, self.num_hosts // i)
550
+ lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
551
+ logical_meshes.append(
552
+ LogicalDeviceMesh(self.phy_devices_id.shape,
553
+ lmesh_phy_ids, self.config, alpha,
554
+ beta, sharp, self.prof_database,
555
+ self.shape_consistency_manager,
556
+ self.host_ips))
557
+ # (2, 1) -> (2, 1)
558
+ # (2, 8) -> (2, 8)
559
+ # (1, 2) -> (1, 2)
560
+ # (1, 3) -> (1, 3)
561
+ # (1, 5) -> (1, 5)
562
+ if 0 == len(logical_meshes):
563
+ logical_meshes.append(self.as_logical_mesh())
564
+ return logical_meshes
565
+
566
+ '''
567
+ We assume the pipeline and the device mesh can be split evenly.
568
+ '''
569
+
570
+ def _list_all_sub_meshes(self):
571
+ sub_meshes = []
572
+ for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
573
+ if self.num_devices_per_host % num_devices_per_stage == 0:
574
+ num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
575
+ sub_meshes.append(
576
+ self.split_pipeline_meshes(num_stages,
577
+ num_devices_per_stage)[0])
578
+ for num_hosts_per_stage in range(2, self.num_hosts + 1):
579
+ if self.num_hosts % num_hosts_per_stage == 0:
580
+ num_stages = self.num_hosts // num_hosts_per_stage
581
+ sub_meshes.append(
582
+ self.split_pipeline_meshes(
583
+ num_stages,
584
+ num_hosts_per_stage * self.num_devices_per_host)[0])
585
+ return sub_meshes
586
+
587
+ def list_all_pipeline_configs(self):
588
+ configs = []
589
+ for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
590
+ if self.num_devices_per_host % num_devices_per_stage == 0:
591
+ num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
592
+ configs.append((num_stages, num_devices_per_stage))
593
+ for num_hosts_per_stage in range(2, self.num_hosts + 1):
594
+ if self.num_hosts % num_hosts_per_stage == 0:
595
+ num_stages = self.num_hosts // num_hosts_per_stage
596
+ configs.append(
597
+ (num_stages,
598
+ num_hosts_per_stage * self.num_devices_per_host))
599
+ return configs
600
+
601
+ def profile_s_curve(self, step):
602
+ sub_phy_device_meshes = self._list_all_sub_meshes()
603
+ for phy_mesh in sub_phy_device_meshes:
604
+ lmeshes = phy_mesh.get_logical_meshes()
605
+ self._profile_logical_meshes(lmeshes, step)
606
+ if 2 == step:
607
+ self.save_profile_database()
608
+
609
+ def profile_alpha_beta(self):
610
+ alpha = [250, 25]
611
+ beta = [100, 100]
612
+ return alpha, beta
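A hedged end-to-end sketch of the mesh utilities above: build a PhysicalDeviceMesh for 2 hosts x 4 GPUs, enumerate pipeline splits, and estimate an all_reduce cost with the default alpha-beta model. The cluster_key and message size are illustrative assumptions.

import numpy as np
from tensorrt_llm.auto_parallel.config import AutoParallelConfig
from tensorrt_llm.auto_parallel.device_mesh import PhysicalDeviceMesh

config = AutoParallelConfig(world_size=8, gpus_per_node=4,
                            cluster_key="A100-SXM-80GB")
phy_mesh = PhysicalDeviceMesh(np.arange(8).reshape(2, 4), config)
print(phy_mesh.list_all_pipeline_configs())      # [(8, 1), (4, 2), (2, 4), (1, 8)]
stage_mesh = phy_mesh.split_pipeline_meshes(num_stages=2, num_devices_per_stage=4)[0]
lmesh = stage_mesh.get_logical_meshes()[0]       # 4 GPUs folded into a (2, 2) logical mesh
# alpha-beta model: bytes in, microseconds out (size / (GB/s) * 1e-3 + latency)
print(lmesh.estimate_comm_cost("all_reduce", (1, ), 64 * 2**20, "float16"))
phy_mesh.close()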
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/node_graph.py ADDED
@@ -0,0 +1,347 @@
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+ import tensorrt as trt
5
+
6
+ from .pipeline_graph import PipelineGraph
7
+ from .runtime_profiling import RuntimeProfiler
8
+ from .simplifier import GraphConfig, StageType
9
+ from .solver import CostGraph, Solver
10
+ from .tensor_parallel.activation_node import Activation
11
+ from .tensor_parallel.assertion_node import Assertion
12
+ from .tensor_parallel.cast_node import Cast
13
+ from .tensor_parallel.concatenation_node import Concatenation
14
+ from .tensor_parallel.constant_node import Constant
15
+ from .tensor_parallel.elementwise_node import ElementWise
16
+ from .tensor_parallel.fill_node import Fill
17
+ from .tensor_parallel.gather_node import Gather
18
+ from .tensor_parallel.identity_node import Identity
19
+ from .tensor_parallel.input_node import InputNode
20
+ from .tensor_parallel.matmul_node import MatrixMultiply
21
+ from .tensor_parallel.node import Node
22
+ from .tensor_parallel.normalization_node import Normalization
23
+ from .tensor_parallel.output_node import OuputNode
24
+ from .tensor_parallel.p2p_node import P2PNode, P2PType
25
+ from .tensor_parallel.plugin_node import PluginNode
26
+ from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
27
+ from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
28
+ from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
29
+ from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
30
+ from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
31
+ RMSnormPlugin)
32
+ from .tensor_parallel.reduce_node import Reduce
33
+ from .tensor_parallel.select_node import Select
34
+ from .tensor_parallel.shape_node import Shape
35
+ from .tensor_parallel.shuffle_node import Shuffle
36
+ from .tensor_parallel.slice_node import Slice
37
+ from .tensor_parallel.softmax_node import SoftMax
38
+ from .tensor_parallel.unary_node import Unary
39
+
40
+ LAYER_TYPE_2_NODE_TYPE = {
41
+ trt.LayerType.ACTIVATION: Activation,
42
+ trt.LayerType.ASSERTION: Assertion,
43
+ trt.LayerType.CAST: Cast,
44
+ trt.LayerType.CONCATENATION: Concatenation,
45
+ trt.LayerType.CONSTANT: Constant,
46
+ trt.LayerType.ELEMENTWISE: ElementWise,
47
+ trt.LayerType.FILL: Fill,
48
+ trt.LayerType.GATHER: Gather,
49
+ trt.LayerType.IDENTITY: Identity,
50
+ trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
51
+ trt.LayerType.NORMALIZATION: Normalization,
52
+ trt.LayerType.PLUGIN_V2: PluginNode,
53
+ trt.LayerType.REDUCE: Reduce,
54
+ trt.LayerType.SELECT: Select,
55
+ trt.LayerType.SHAPE: Shape,
56
+ trt.LayerType.SHUFFLE: Shuffle,
57
+ trt.LayerType.SLICE: Slice,
58
+ trt.LayerType.SOFTMAX: SoftMax,
59
+ trt.LayerType.UNARY: Unary,
60
+ }
61
+ # TODO: BertAttention/All Quant plugins
62
+ PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
63
+ 'GPTAttention': GPTAttentionPlugin,
64
+ 'Gemm': GemmPlugin,
65
+ 'Layernorm': LayernormPlugin,
66
+ 'Rmsnorm': RMSnormPlugin,
67
+ 'Lookup': LookupPlugin,
68
+ 'Identity': IdentityPlugin,
69
+ }
70
+
71
+
72
+ class NodeGraph:
73
+
74
+ def __init__(self, graph: PipelineGraph):
75
+ self._nodes = {}
76
+
77
+ # construct nodes
78
+ for input in graph.inputs:
79
+ self._nodes[input.name] = InputNode(input)
80
+ for layer in graph.layers:
81
+ layer.to_base_class()
82
+ if "p2p_type" in layer.attrs:
83
+ self._nodes[layer.name] = P2PNode(layer)
84
+ elif layer.type == trt.LayerType.PLUGIN_V2:
85
+ layer.to_subclass()
86
+ plugin_type = layer.as_trt().plugin.plugin_type
87
+ layer.to_base_class()
88
+ if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE:
89
+ node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer)
90
+ else:
91
+ node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
92
+ self._nodes[layer.name] = node
93
+ else:
94
+ node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
95
+ self._nodes[layer.name] = node
96
+ for output in graph.outputs:
97
+ self._nodes[output.name] = OuputNode(output)
98
+ for node in self.nodes:
99
+ node.post_init(self)
100
+ node.node_runtime_profiler = RuntimeProfiler()
101
+
102
+ def get_node(self, name):
103
+ return self._nodes[name]
104
+
105
+ @property
106
+ def nodes(self) -> List[Node]:
107
+ return [*self._nodes.values()]
108
+
109
+ def assign_cost_weights(self, graph_config: GraphConfig):
110
+ layer_mapping = graph_config.graph_mapping.layer_mapping
111
+ for layer_name in layer_mapping.values():
112
+ node = self.get_node(layer_name)
113
+ node.sharding_weight += 1
114
+ node.resharding_weight += 1
115
+ same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping
116
+ for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
117
+ node = self.get_node(layer_name)
118
+ same_spec_node = self.get_node(same_spec_layer_name)
119
+ same_spec_node.sharding_weight = node.sharding_weight
120
+ same_spec_node.resharding_weight = node.resharding_weight
121
+
122
+ def set_slowest_stage(self, stage_type: StageType,
123
+ graph_config: GraphConfig):
124
+ num_micro_batches = graph_config.num_micro_batches
125
+ block_per_stage = graph_config.num_blocks // graph_config.num_stages
126
+ block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
127
+ for node in self.nodes:
128
+ node.pipeline_weight = 0
129
+ node.cost_level = -1
130
+ if node.stage_type == StageType.START:
131
+ if stage_type == StageType.START:
132
+ node.pipeline_weight = num_micro_batches - 1
133
+ node.cost_level = 1
134
+ else:
135
+ node.cost_level = 0
136
+ if stage_type == StageType.START and node.in_start_block:
137
+ node.pipeline_weight = block_pipeline_weight
138
+ if node.stage_type == StageType.END:
139
+ if stage_type == StageType.END:
140
+ node.pipeline_weight = num_micro_batches - 1
141
+ node.cost_level = 1
142
+ else:
143
+ node.cost_level = 0
144
+ if stage_type == StageType.END and node.in_end_block:
145
+ node.pipeline_weight = block_pipeline_weight
146
+ if isinstance(node, P2PNode):
147
+ if (graph_config.has_cross_host
148
+ and node.p2p_type == P2PType.CROSS_HOST) or (
149
+ not graph_config.has_cross_host
150
+ and node.p2p_type == P2PType.CROSS_DEVICE):
151
+ if stage_type == StageType.BLOCK:
152
+ node.pipeline_weight += num_micro_batches - 1
153
+ node.cost_level = 1
154
+ else:
155
+ node.cost_level = 0
156
+ elif (graph_config.has_cross_device
157
+ and node.p2p_type == P2PType.CROSS_DEVICE) or (
158
+ not graph_config.has_cross_device
159
+ and node.p2p_type == P2PType.CROSS_HOST):
160
+ node.pipeline_weight += num_micro_batches - 1
161
+ if stage_type == StageType.BLOCK and node.in_slowest_block:
162
+ node.pipeline_weight = block_pipeline_weight
163
+
164
+ def get_cost_graph(self, lmesh):
165
+ leaf_strategies = []
166
+ for node in self.nodes:
167
+ if node.is_replicated:
168
+ node.set_strategy(None, lmesh)
169
+ else:
170
+ node.collect_strategies(lmesh)
171
+ for node in self.nodes:
172
+ strategies_vector = node.update_resharding_cost()
173
+ if len(strategies_vector) != 0:
174
+ leaf_strategies.append(strategies_vector)
175
+ cost_graph = CostGraph(leaf_strategies)
176
+ return cost_graph
177
+
178
+ def find_solution(self, cost_graph, memory_budget):
179
+ solver = Solver(cost_graph, memory_budget=memory_budget)
180
+ solution = solver.find_solution()[1]
181
+
182
+ graph_strategy = solution.node_best_strategy
183
+ for node_name, strategy in graph_strategy.items():
184
+ node = self._nodes[node_name]
185
+ for idx, pre_node in enumerate(node.predecessor_nodes):
186
+ if pre_node is None:
187
+ continue
188
+ if pre_node.node_name not in strategy.best_resharding_cost:
189
+ continue
190
+ strategy.best_resharding_cost[
191
+ idx] = strategy.best_resharding_cost[pre_node.node_name]
192
+ strategy.node_names[idx] = pre_node.node_name
193
+ for key in list(strategy.best_resharding_cost.keys()):
194
+ if isinstance(key, str):
195
+ del strategy.best_resharding_cost[key]
196
+
197
+ return solution
198
+
199
+ def visualize(self, name='pp_graph'):
200
+ with open(name + '.dot', 'w') as f:
201
+ f.write("digraph {\n")
202
+ '''
203
+ f.write(" // Value Nodes\n")
204
+ for name, tensor in self._tensors.items():
205
+ f.write(" \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape))
206
+ '''
207
+ f.write(" // Operation Nodes\n")
208
+ for name, node in self._nodes.items():
209
+ fillcolor = 'white'
210
+ if 'MATRIX_MULTIPLY' in name:
211
+ fillcolor = 'green'
212
+ label = name
213
+ if len(node.outputs) > 0:
214
+ label = name + '\\n' + str(node.outputs[0].shape)
215
+ f.write(
216
+ " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
217
+ .format(name, fillcolor, label))
218
+ f.write(" // Edges\n")
219
+ for name, node in self._nodes.items():
220
+ for successor_node in node.successor_nodes:
221
+ if successor_node:
222
+ f.write(" \"{}\" ->\"{}\";\n".format(
223
+ name, successor_node.node_name))
224
+ f.write(" }\n")
225
+
226
+ def visualize_solution(self,
227
+ solution,
228
+ fname='pp_graph_solution',
229
+ ignore_shape_io=True):
230
+ with open(fname + '.dot', 'w') as f:
231
+ names, costs, block_ids = [], [], []
232
+ f.write("digraph {\n")
233
+ f.write(" // Operation Nodes\n")
234
+ for name, node in self._nodes.items():
235
+ if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
236
+ continue
237
+ cost = 0.0
238
+ fillcolor = 'white'
239
+ if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
240
+ fillcolor = 'orange'
241
+ elif '_same_spec' in name:
242
+ fillcolor = 'gray'
243
+ elif 'p2p_block' in name:
244
+ fillcolor = 'blue'
245
+ elif 'PLUGIN' in name:
246
+ fillcolor = 'yellow'
247
+
248
+ shape = 'box'
249
+ if 'output_node' == node.node_type or 'input_node' == node.node_type:
250
+ shape = 'ellipse'
251
+ fillcolor = 'green'
252
+
253
+ label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
254
+ if len(node.inputs) > 0:
255
+ for idx, input in enumerate(node.inputs):
256
+ if not input:
257
+ continue
258
+ label = label + f'\\ninput{idx}_' + str(
259
+ input.shape) + f'_{input.dtype_str_size[0]}_'
260
+ if node.node_name in solution.node_best_strategy:
261
+ best_strategy = solution.node_best_strategy[
262
+ node.node_name]
263
+ shard_seq = str(
264
+ best_strategy.sharding_specs[f'input{idx}'].
265
+ sharding_sequence)
266
+ label = label + shard_seq
267
+ if idx not in best_strategy.best_resharding_cost:
268
+ continue
269
+ rcosts = best_strategy.best_resharding_cost[idx][0]
270
+ comm_action_sequence, resharding_cost = rcosts[
271
+ 1], rcosts[2]
272
+ if len(comm_action_sequence) > 0:
273
+ label = label + '|'
274
+ for commspec in comm_action_sequence:
275
+ comm = [
276
+ commspec.comm_pattern, commspec.gather_dim,
277
+ commspec.shard_dim,
278
+ commspec.logical_process_axis
279
+ ]
280
+ label = label + '->' + str(comm)
281
+ if resharding_cost > 0:
282
+ label = label + '_rcost{:.2}'.format(
283
+ resharding_cost)
284
+ cost = cost + resharding_cost
285
+ if len(node.outputs) > 0:
286
+ best_strategy = None
287
+ for idx, output in enumerate(node.outputs):
288
+ label = label + f'\\noutput{idx}_' + str(
289
+ output.shape) + f'_{output.dtype_str_size[0]}'
290
+ if node.node_name in solution.node_best_strategy:
291
+ best_strategy = solution.node_best_strategy[
292
+ node.node_name]
293
+ shard_seq = str(
294
+ best_strategy.sharding_specs[f'output{idx}'].
295
+ sharding_sequence)
296
+ comm = None
297
+ if f'output{idx}' in best_strategy.communication_actions:
298
+ commspec = best_strategy.communication_actions[
299
+ f'output{idx}']
300
+ comm = [
301
+ commspec.comm_pattern, commspec.gather_dim,
302
+ commspec.shard_dim,
303
+ commspec.logical_process_axis
304
+ ]
305
+ label = label + '_' + shard_seq
306
+ if comm:
307
+ label = label + f' | {comm}'
308
+ if best_strategy:
309
+ cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
310
+ label = label + '| scost{:.2}'.format(
311
+ best_strategy.sharding_cost)
312
+ if best_strategy.communication_cost > 0:
313
+ label = label + ' | ccost{:.2}'.format(
314
+ best_strategy.communication_cost)
315
+ names.append(name)
316
+ costs.append(cost)
317
+ block_ids.append([
318
+ node.building_block_id, node.cost_level,
319
+ node.sharding_weight + node.pipeline_weight,
320
+ node.same_spec_id
321
+ ])
322
+ f.write(
323
+ " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
324
+ .format(name, fillcolor, label, shape))
325
+ f.write(" // Edges\n")
326
+ for name, node in self._nodes.items():
327
+ if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
328
+ continue
329
+ for successor_node in node.successor_nodes:
330
+ if successor_node:
331
+ if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io:
332
+ continue
333
+ f.write(" \"{}\" ->\"{}\";\n".format(
334
+ name, successor_node.node_name))
335
+ f.write(" }\n")
336
+ df = pd.DataFrame.from_dict({
337
+ 'node':
338
+ names,
339
+ 'cost':
340
+ costs,
341
+ 'block_id': [block[0] for block in block_ids],
342
+ 'cost_level': [block[1] for block in block_ids],
343
+ 'sharding_weight': [block[2] for block in block_ids],
344
+ 'same_spec_id': [block[3] for block in block_ids]
345
+ })
346
+ df['weight_cost'] = df['sharding_weight'] * df['cost']
347
+ df.to_csv(fname + '.csv')
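For orientation, the typical flow through NodeGraph looks roughly like the sketch below; `graph` (a PipelineGraph built from a traced network), `graph_config` (a simplifier.GraphConfig) and `lmesh` (a LogicalDeviceMesh) are assumed to come from the surrounding auto_parallel pipeline, and the memory-budget unit is an assumption.

node_graph = NodeGraph(graph)
node_graph.assign_cost_weights(graph_config)
cost_graph = node_graph.get_cost_graph(lmesh)
budget = lmesh.cluster_info.memory_budget_per_device * 1024**3  # bytes, assumed unit
solution = node_graph.find_solution(cost_graph, budget)
node_graph.visualize_solution(solution, 'pp_graph_solution')    # writes .dot and .csv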
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py ADDED
The diff for this file is too large to render. See raw diff
 
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/pipeline_graph.py ADDED
@@ -0,0 +1,1035 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import tensorrt as trt
6
+ import torch
7
+
8
+ from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch
9
+ from tensorrt_llm.logger import logger
10
+ from tensorrt_llm.network import Network, get_plugin_info, set_plugin_info
11
+ from tensorrt_llm.plugin.plugin import PluginConfig
12
+ from tensorrt_llm.runtime.session import Session
13
+
14
+ from .utils import (current_flags, get_builder_flags, get_sorted_layer_ids,
15
+ get_strongly_typed, get_trt_network, set_trt_network,
16
+ to_base_class_layer, to_subclass_layer)
17
+
18
+
19
+ class Tensor:
20
+
21
+ def __init__(self, graph: "PipelineGraph"):
22
+ self._graph = graph
23
+ self._trt = None
24
+ self._shape = None
25
+ self._max_shape = None
26
+ self._value = None
27
+ self.producer: Layer = None
28
+ self.output_index = None
29
+ self.consumers = []
30
+ self.graph_input_index = -1
31
+ self.graph_output_index = -1
32
+ self.attrs = {}
33
+
34
+ @staticmethod
35
+ def from_trt(graph: "PipelineGraph", trt_tensor: trt.ITensor):
36
+ tensor = Tensor(graph)
37
+ tensor._trt = trt_tensor
38
+ return tensor
39
+
40
+ def as_trt(self) -> trt.ITensor:
41
+ return self._trt
42
+
43
+ def copy(self) -> "Tensor":
44
+ tensor = Tensor(self._graph)
45
+ tensor._trt = self._trt
46
+ tensor._shape = self._shape
47
+ tensor._max_shape = self._max_shape
48
+ tensor._value = self._value
49
+ tensor.producer = self.producer
50
+ tensor.output_index = self.output_index
51
+ tensor.consumers = [*self.consumers]
52
+ tensor.graph_input_index = self.graph_input_index
53
+ tensor.graph_output_index = self.graph_output_index
54
+ tensor.attrs = self.attrs.copy()
55
+ return tensor
56
+
57
+ @property
58
+ def graph(self) -> "PipelineGraph":
59
+ return self._graph
60
+
61
+ @property
62
+ def name(self) -> str:
63
+ return self._trt.name
64
+
65
+ @name.setter
66
+ def name(self, name: str):
67
+ old_name = self._trt.name
68
+ if name != old_name:
69
+ self._trt.name = name
70
+ self.graph._tensors[name] = self
71
+ del self.graph._tensors[old_name]
72
+ if self.is_graph_input:
73
+ self.graph._inputs[name] = self
74
+ del self.graph._inputs[old_name]
75
+ elif self.is_graph_output:
76
+ self.graph._outputs[name] = self
77
+ del self.graph._outputs[old_name]
78
+
79
+ @property
80
+ def shape(self):
81
+ return self._shape
82
+
83
+ @property
84
+ def max_shape(self):
85
+ return self._max_shape
86
+
87
+ @property
88
+ def raw_shape(self):
89
+ assert isinstance(self._trt, trt.ITensor)
90
+ return self._trt.shape
91
+
92
+ @shape.setter
93
+ def shape(self, shape):
94
+ self._shape = shape
95
+
96
+ @max_shape.setter
97
+ def max_shape(self, max_shape):
98
+ self._max_shape = max_shape
99
+
100
+ @raw_shape.setter
101
+ def raw_shape(self, raw_shape):
102
+ assert isinstance(self._trt, trt.ITensor)
103
+ self._trt.shape = raw_shape
104
+
105
+ @property
106
+ def value(self):
107
+ return self._value
108
+
109
+ @value.setter
110
+ def value(self, value):
111
+ self._value = value
112
+
113
+ @property
114
+ def dtype(self):
115
+ return self._trt.dtype
116
+
117
+ @property
118
+ def broadcast_across_batch(self):
119
+ return self._trt.broadcast_across_batch
120
+
121
+ @property
122
+ def dtype_size(self):
123
+ return self.dtype.itemsize
124
+
125
+ @property
126
+ def dtype_str(self):
127
+ return trt_dtype_to_str(self.dtype)
128
+
129
+ @property
130
+ def dtype_str_size(self):
131
+ return [trt_dtype_to_str(self.dtype), self.dtype.itemsize]
132
+
133
+ @property
134
+ def is_graph_input(self) -> bool:
135
+ return self.graph_input_index != -1
136
+
137
+ @property
138
+ def is_graph_output(self) -> bool:
139
+ return self.graph_output_index != -1
140
+
141
+ @property
142
+ def is_graph_io(self) -> bool:
143
+ return self.is_graph_input or self.is_graph_output
144
+
145
+
146
+ class Layer:
147
+
148
+ def __init__(self, graph):
149
+ self._graph = graph
150
+ self._trt = None
151
+ self._index = None
152
+ self._inputs = []
153
+ self._outputs = []
154
+ self._is_shape_io = False
155
+ self.attrs = {}
156
+
157
+ @staticmethod
158
+ def from_trt(graph, trt_layer, index):
159
+ layer = Layer(graph)
160
+ layer._trt = trt_layer
161
+ layer._index = index
162
+ for i in range(trt_layer.num_inputs):
163
+ input = trt_layer.get_input(i)
164
+ if input is not None:
165
+ layer._inputs.append(graph.get_tensor(input.name))
166
+ layer._inputs[i].consumers.append((layer, i))
167
+ else:
168
+ layer._inputs.append(None)
169
+ for i in range(trt_layer.num_outputs):
170
+ output = trt_layer.get_output(i)
171
+ layer._outputs.append(graph.get_tensor(output.name))
172
+ layer._outputs[i].producer = layer
173
+ layer._outputs[i].output_index = i
174
+ set_trt_network(trt_layer, graph.as_trt())
175
+ return layer
176
+
177
+ def as_trt(self) -> trt.ILayer:
178
+ return self._trt
179
+
180
+ @property
181
+ def graph(self) -> "PipelineGraph":
182
+ return self._graph
183
+
184
+ @property
185
+ def name(self) -> str:
186
+ return self._trt.name
187
+
188
+ @name.setter
189
+ def name(self, name: str):
190
+ old_name = self._trt.name
191
+ if name != old_name:
192
+ self._trt.name = name
193
+ self.graph._layers[name] = self
194
+ del self.graph._layers[old_name]
195
+
196
+ @property
197
+ def type(self) -> trt.LayerType:
198
+ return self._trt.type
199
+
200
+ @property
201
+ def index(self) -> int:
202
+ return self._index
203
+
204
+ @property
205
+ def inputs(self) -> List[Tensor]:
206
+ return self._inputs
207
+
208
+ @property
209
+ def outputs(self) -> List[Tensor]:
210
+ return self._outputs
211
+
212
+ def get_input(self, index: int) -> Tensor:
213
+ return self._inputs[index]
214
+
215
+ def get_output(self, index: int) -> Tensor:
216
+ return self._outputs[index]
217
+
218
+ @property
219
+ def num_inputs(self) -> int:
220
+ return self._trt.num_inputs
221
+
222
+ @property
223
+ def num_outputs(self) -> int:
224
+ return self._trt.num_outputs
225
+
226
+ @property
227
+ def is_shape_io(self) -> bool:
228
+ return self._is_shape_io
229
+
230
+ def to_subclass(self):
231
+ to_subclass_layer(self._trt)
232
+
233
+ def to_base_class(self):
234
+ to_base_class_layer(self._trt)
235
+
236
+ def assign_shapes(self, shapes, values):
237
+ for output in self.outputs:
238
+ output.shape = shapes[output.name]
239
+ output.value = values.get(output.name)
240
+
241
+
242
+ @dataclass
243
+ class GraphRunner:
244
+ session: Session
245
+ inputs: Dict[str, torch.Tensor]
246
+ outputs: Dict[str, torch.Tensor]
247
+ stream: torch.Stream
248
+
249
+ def run(self):
250
+ cuda_stream = self.stream.cuda_stream
251
+ assert self.session.run(self.inputs, self.outputs, cuda_stream)
252
+ self.stream.synchronize()
253
+ return self.outputs
254
+
255
+
256
+ class PipelineGraph:
257
+
258
+ def __init__(self):
259
+ self._trt = None
260
+ self._inputs: Dict[str, Tensor] = {}
261
+ self._outputs: Dict[str, Tensor] = {}
262
+ self._layers: Dict[str, Layer] = {}
263
+ self._tensors: Dict[str, Tensor] = {}
264
+ self._io_buffer_mapping = {}
265
+ self._unfilled_weights = {}
266
+ self._auto_parallel_config = None
267
+ self._plugin_config: PluginConfig = None
268
+
269
+ @staticmethod
270
+ def create_graph():
271
+ graph = PipelineGraph()
272
+ trt_builder = trt.Builder(logger.trt_logger)
273
+ explicit_batch_flag = 0
274
+ # Explicit batch flag will be deprecated in TRT 10
275
+ if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
276
+ ):
277
+ explicit_batch_flag = 1 << int(
278
+ trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
279
+ if get_strongly_typed():
280
+ network = trt_builder.create_network(
281
+ explicit_batch_flag
282
+ | (1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)))
283
+ else:
284
+ network = trt_builder.create_network(explicit_batch_flag)
285
+ graph._trt = network
286
+ return graph
287
+
288
+ def _register_unfilled_weights(self, layer_name, weights, values):
289
+ self._unfilled_weights[layer_name] = (weights, values)
290
+
291
+ def _add_tensor(self, tensor, old_tensor, prefix):
292
+ if prefix is not None:
293
+ tensor.name = prefix + old_tensor.name
294
+ else:
295
+ tensor.name = old_tensor.name
296
+ tensor.location = old_tensor.location
297
+ if old_tensor.dynamic_range is not None:
298
+ tensor.dynamic_range = old_tensor.dynamic_range
299
+ if tensor.is_network_input:
300
+ tensor.shape = old_tensor.shape
301
+ for i in range(len(old_tensor.shape)):
302
+ name = old_tensor.get_dimension_name(i)
303
+ if name is not None:
304
+ tensor.set_dimension_name(i, name)
305
+ return self._register_tensor(tensor)
306
+
307
+ def _register_tensor(self, tensor):
308
+ wrapped_tensor = Tensor.from_trt(self, tensor)
309
+ assert tensor.name not in self._tensors
310
+ self._tensors[tensor.name] = wrapped_tensor
311
+ return wrapped_tensor
312
+
313
+ def add_input(self, tensor, prefix=None):
314
+ tensor_name = tensor.name
315
+ if prefix is not None:
316
+ tensor_name = prefix + tensor_name
317
+ input = self._trt.add_input(tensor_name, tensor.dtype, tensor.shape)
318
+ new_tensor = self._add_tensor(input, tensor, prefix)
319
+ new_tensor.graph_input_index = len(self._inputs)
320
+ self._inputs[tensor_name] = new_tensor
321
+ return new_tensor
322
+
323
+ def register_input(self, tensor, index=None):
324
+ if index is None:
325
+ index = self.num_inputs - 1
326
+ assert self._trt.get_input(index).name == tensor.name
327
+ wrapped_input = self._register_tensor(tensor)
328
+ wrapped_input.graph_input_index = index
329
+ self._inputs[tensor.name] = wrapped_input
330
+ return wrapped_input
331
+
332
+ def add_output(self, tensor, prefix=None):
333
+ tensor_name = tensor.name
334
+ if prefix is not None:
335
+ tensor_name = prefix + tensor_name
336
+ output = self.get_tensor(tensor_name)
337
+ output.graph_output_index = len(self._outputs)
338
+ trt_output = output.as_trt()
339
+ self._trt.mark_output(trt_output)
340
+ trt_output.dtype = tensor.dtype
341
+ self._outputs[tensor_name] = output
342
+ return output
343
+
344
+ def add_output_shape(self, tensor, prefix=None):
345
+ tensor_name = tensor.name
346
+ if prefix is not None:
347
+ tensor_name = prefix + tensor_name
348
+ output = self.get_tensor(tensor_name)
349
+ trt_output = output.as_trt()
350
+ self._trt.mark_output_for_shapes(trt_output)
351
+ trt_output.dtype = tensor.dtype
352
+ self._outputs[tensor_name] = output
353
+ return output
354
+
355
+ def add_layer(
356
+ self,
357
+ layer,
358
+ input_mapping=None,
359
+ prefix=None,
360
+ updated_attrs=None,
361
+ ) -> Layer:
362
+
363
+ def get_input(i):
364
+ name = layer.get_input(i).name
365
+ if prefix is not None:
366
+ name = prefix + name
367
+ if input_mapping is not None and name in input_mapping:
368
+ name = input_mapping[name]
369
+ return self.get_tensor(name).as_trt()
370
+
371
+ network = self._trt
372
+ layer_type = layer.type
373
+ to_subclass_layer(layer)
374
+ if layer_type == trt.LayerType.ACTIVATION:
375
+ trt_input = get_input(0)
376
+ new_layer = network.add_activation(trt_input, layer.type)
377
+ new_layer.alpha = layer.alpha
378
+ new_layer.beta = layer.beta
379
+ elif layer_type == trt.LayerType.CONCATENATION:
380
+ trt_inputs = [get_input(i) for i in range(layer.num_inputs)]
381
+ new_layer = network.add_concatenation(trt_inputs)
382
+ new_layer.axis = layer.axis
383
+ elif layer_type == trt.LayerType.CONSTANT:
384
+ new_layer = network.add_constant(layer.shape, layer.weights)
385
+ elif layer_type == trt.LayerType.ELEMENTWISE:
386
+ new_layer = network.add_elementwise(get_input(0), get_input(1),
387
+ layer.op)
388
+ elif layer_type == trt.LayerType.FILL:
389
+ if layer.num_inputs >= 1 and layer.get_input(0) is not None:
390
+ shape_input = get_input(0)
391
+ shape = [1]
392
+ else:
393
+ shape_input = None
394
+ shape = layer.shape
395
+ new_layer = network.add_fill(shape, layer.operation, layer.to_type)
396
+ if shape_input is not None:
397
+ new_layer.set_input(0, shape_input)
398
+ if layer.num_inputs >= 1 and layer.get_input(0) is not None:
399
+ new_layer.set_input(0, get_input(0))
400
+ if layer.num_inputs >= 2 and layer.get_input(1) is not None:
401
+ new_layer.set_input(1, get_input(1))
402
+ else:
403
+ new_layer.alpha = layer.alpha
404
+ if layer.num_inputs >= 3 and layer.get_input(2) is not None:
405
+ new_layer.set_input(2, get_input(2))
406
+ else:
407
+ new_layer.beta = layer.beta
408
+ elif layer_type == trt.LayerType.GATHER:
409
+ trt_input = get_input(0)
410
+ trt_indices = get_input(1)
411
+ new_layer = network.add_gather_v2(trt_input, trt_indices,
412
+ layer.mode)
413
+ new_layer.axis = layer.axis
414
+ new_layer.num_elementwise_dims = layer.num_elementwise_dims
415
+ new_layer.mode = layer.mode
416
+ elif layer_type == trt.LayerType.MATRIX_MULTIPLY:
417
+ new_layer = network.add_matrix_multiply(get_input(0), layer.op0,
418
+ get_input(1), layer.op1)
419
+ elif layer_type == trt.LayerType.REDUCE:
420
+ new_layer = network.add_reduce(get_input(0), layer.op, layer.axes,
421
+ layer.keep_dims)
422
+ elif layer_type == trt.LayerType.SELECT:
423
+ trt_condition = get_input(0)
424
+ trt_then = get_input(1)
425
+ trt_else = get_input(2)
426
+ new_layer = network.add_select(trt_condition, trt_then, trt_else)
427
+ elif layer_type == trt.LayerType.SHUFFLE:
428
+ new_layer = network.add_shuffle(get_input(0))
429
+ new_layer.first_transpose = layer.first_transpose
430
+ new_layer.second_transpose = layer.second_transpose
431
+ new_layer.zero_is_placeholder = layer.zero_is_placeholder
432
+ if layer.num_inputs >= 2:
433
+ trt_reshape_dims_tensor = get_input(1)
434
+ new_layer.set_input(1, trt_reshape_dims_tensor)
435
+ else:
436
+ new_layer.reshape_dims = layer.reshape_dims
437
+ elif layer_type == trt.LayerType.SLICE:
438
+ if layer.num_inputs >= 2 and layer.get_input(1) is not None:
439
+ trt_start = get_input(1)
440
+ start = []
441
+ else:
442
+ trt_start = None
443
+ start = layer.start
444
+ if layer.num_inputs >= 3 and layer.get_input(2) is not None:
445
+ trt_shape = get_input(2)
446
+ shape = []
447
+ else:
448
+ trt_shape = None
449
+ shape = layer.shape
450
+ if layer.num_inputs >= 4 and layer.get_input(3) is not None:
451
+ trt_stride = get_input(3)
452
+ stride = []
453
+ else:
454
+ trt_stride = None
455
+ stride = layer.stride
456
+ new_layer = network.add_slice(get_input(0), start, shape, stride)
457
+ new_layer.mode = layer.mode
458
+ if trt_start is not None:
459
+ new_layer.set_input(1, trt_start)
460
+ if trt_shape is not None:
461
+ new_layer.set_input(2, trt_shape)
462
+ if trt_stride is not None:
463
+ new_layer.set_input(3, trt_stride)
464
+ elif layer_type == trt.LayerType.SOFTMAX:
465
+ new_layer = network.add_softmax(get_input(0))
466
+ new_layer.axes = layer.axes
467
+ elif layer_type == trt.LayerType.UNARY:
468
+ new_layer = network.add_unary(get_input(0), layer.op)
469
+ elif layer_type == trt.LayerType.SHAPE:
470
+ new_layer = network.add_shape(get_input(0))
471
+ elif layer_type == trt.LayerType.ASSERTION:
472
+ new_layer = network.add_assertion(get_input(0), layer.message)
473
+ elif layer_type == trt.LayerType.CAST:
474
+ new_layer = network.add_cast(get_input(0), layer.to_type)
475
+ elif layer_type == trt.LayerType.NORMALIZATION:
476
+ trt_input = get_input(0)
477
+ trt_scale = get_input(1)
478
+ trt_bias = get_input(2)
479
+ new_layer = network.add_normalization(trt_input, trt_scale,
480
+ trt_bias, layer.axes)
481
+ new_layer.epsilon = layer.epsilon
482
+ new_layer.num_groups = layer.num_groups
483
+ new_layer.compute_precision = layer.compute_precision
484
+ elif layer_type == trt.LayerType.IDENTITY:
485
+ new_layer = network.add_identity(get_input(0))
486
+ elif layer_type == trt.LayerType.PLUGIN_V2:
487
+ plugin = layer.plugin
488
+ updated = False
489
+ if (updated_attrs is not None
490
+ and updated_attrs.get("plugin") is not None):
491
+ plugin = updated_attrs["plugin"]
492
+ updated = True
493
+ updated_attrs = None
494
+ new_layer = network.add_plugin_v2(
495
+ [get_input(i) for i in range(layer.num_inputs)],
496
+ plugin,
497
+ )
498
+ else:
499
+ raise NotImplementedError(
500
+ "Unsupported layer type: {}".format(layer_type))
501
+
502
+ if updated_attrs is not None:
503
+ for attr_name, attr_value in updated_attrs.items():
504
+ setattr(new_layer, attr_name, attr_value)
505
+
506
+ to_base_class_layer(layer)
507
+ to_base_class_layer(new_layer)
508
+ layer_index = network.num_layers - 1
509
+ layer_name = layer.name
510
+ if prefix is not None:
511
+ layer_name = prefix + layer_name
512
+ new_layer.name = layer_name
513
+ new_layer.metadata = new_layer.name
514
+ if layer.precision_is_set:
515
+ new_layer.precision = layer.precision
516
+ for i in range(layer.num_outputs):
517
+ if layer.output_type_is_set(i):
518
+ new_layer.set_output_type(i, layer.get_output_type(i))
519
+ output = new_layer.get_output(i)
520
+ self._add_tensor(output, layer.get_output(i), prefix)
521
+ wrapped_layer = Layer.from_trt(self, new_layer, layer_index)
522
+ assert layer_name not in self._layers
523
+ self._layers[layer_name] = wrapped_layer
524
+ if layer_type == trt.LayerType.PLUGIN_V2:
525
+ if not updated:
526
+ plugin_info = get_plugin_info(get_trt_network(layer),
527
+ layer.name)
528
+ set_plugin_info(self.as_trt(), new_layer.name, plugin_info)
529
+ return wrapped_layer
530
+
531
+ def register_layer(self, layer, index=None):
532
+ if index is None:
533
+ index = self.num_layers - 1
534
+ assert self._trt.get_layer(index).name == layer.name
535
+ to_base_class_layer(layer)
536
+ for i in range(layer.num_outputs):
537
+ output = layer.get_output(i)
538
+ self._register_tensor(output)
539
+ wrapped_layer = Layer.from_trt(self, layer, index)
540
+ assert layer.name not in self._layers
541
+ self._layers[layer.name] = wrapped_layer
542
+ to_subclass_layer(layer)
543
+ return wrapped_layer
544
+
545
+ def get_runner(
546
+ self,
547
+ shapes=None,
548
+ values=None,
549
+ profile=None,
550
+ timing_cache=None,
551
+ opt_level=None,
552
+ ) -> GraphRunner:
553
+ shapes = shapes or {}
554
+ values = values or {}
555
+ inputs = {}
556
+ outputs = {}
557
+ for input in self.inputs:
558
+ if input is not None:
559
+ value = values.get(input.name)
560
+ if value is None:
561
+ value = input.value
562
+ if value is not None:
563
+ if not isinstance(value, torch.Tensor):
564
+ value = torch.tensor(
565
+ value,
566
+ dtype=trt_dtype_to_torch(input.dtype),
567
+ device='cpu',
568
+ )
569
+ inputs[input.name] = value
570
+ else:
571
+ shape = shapes.get(input.name)
572
+ if shape is None:
573
+ shape = input.shape
574
+ assert shape is not None
575
+ inputs[input.name] = torch.empty(
576
+ tuple(shape),
577
+ dtype=trt_dtype_to_torch(input.dtype),
578
+ device=torch.cuda.current_device(),
579
+ )
580
+ if torch.is_floating_point(inputs[input.name]):
581
+ inputs[input.name].normal_()
582
+ # inputs[input.name][:] = random.choice([2, 3, 5, 7])
583
+ for output in self.outputs:
584
+ if output.as_trt().is_shape_tensor:
585
+ continue
586
+ if output.name in self._io_buffer_mapping:
587
+ input_name = self._io_buffer_mapping[output.name]
588
+ if input_name in inputs:
589
+ outputs[output.name] = inputs[input_name]
590
+ continue
591
+ value = values.get(output.name)
592
+ if value is not None and isinstance(value, torch.Tensor):
593
+ outputs[output.name] = value
594
+ else:
595
+ shape = shapes.get(output.name)
596
+ if shape is None:
597
+ shape = output.shape
598
+ assert shape is not None
599
+ outputs[output.name] = torch.empty(
600
+ tuple(shape),
601
+ dtype=trt_dtype_to_torch(output.dtype),
602
+ device=torch.cuda.current_device(),
603
+ )
604
+ network = self.as_trt()
605
+ config = network.builder.create_builder_config()
606
+ if opt_level is not None:
607
+ config.builder_optimization_level = opt_level
608
+ config.flags = get_builder_flags()
609
+ profile = profile or network.builder.create_optimization_profile()
610
+ profile_index = config.add_optimization_profile(profile)
611
+ if timing_cache is not None:
612
+ config.set_timing_cache(timing_cache, ignore_mismatch=False)
613
+ plan = network.builder.build_serialized_network(network, config)
614
+ if plan is None:
615
+ logger.error('Engine building failed, please check the error log.')
616
+ session = Session.from_serialized_engine(plan)
617
+ stream = torch.cuda.current_stream()
618
+ cuda_stream = stream.cuda_stream
619
+ context = session.context
620
+ context.set_optimization_profile_async(profile_index, cuda_stream)
621
+ runner = GraphRunner(session, inputs, outputs, stream)
622
+ return runner
623
+
624
+ def run(
625
+ self,
626
+ shapes=None,
627
+ values=None,
628
+ profile=None,
629
+ timing_cache=None,
630
+ opt_level=None,
631
+ ):
632
+ return self.get_runner(
633
+ shapes,
634
+ values,
635
+ profile,
636
+ timing_cache,
637
+ opt_level,
638
+ ).run()
639
+
640
+ def duplicate_graph(self):
641
+ graph = PipelineGraph.create_graph()
642
+ network = self.as_trt()
643
+ for i in range(network.num_inputs):
644
+ input = network.get_input(i)
645
+ graph.add_input(input)
646
+ sorted_layer_ids = get_sorted_layer_ids(network)
647
+ for i in sorted_layer_ids:
648
+ layer = network.get_layer(i)
649
+ graph.add_layer(layer)
650
+ for i in range(network.num_outputs):
651
+ output = network.get_output(i)
652
+ if output.is_shape_tensor:
653
+ graph.add_output_shape(output)
654
+ else:
655
+ graph.add_output(output)
656
+ return graph
657
+
658
+ @staticmethod
659
+ def from_trt(trt_network):
660
+ graph = PipelineGraph()
661
+ graph._trt = trt_network
662
+
663
+ # construct inputs and tensors
664
+ for i in range(trt_network.num_inputs):
665
+ trt_input = trt_network.get_input(i)
666
+ tensor = Tensor.from_trt(graph, trt_input)
667
+ tensor.graph_input_index = i
668
+ graph._tensors[tensor.name] = tensor
669
+ graph._inputs[tensor.name] = tensor
670
+ for i in range(trt_network.num_layers):
671
+ trt_layer = trt_network.get_layer(i)
672
+ for i in range(trt_layer.num_outputs):
673
+ trt_output = trt_layer.get_output(i)
674
+ tensor = Tensor.from_trt(graph, trt_output)
675
+ graph._tensors[tensor.name] = tensor
676
+
677
+ # construct layers and outputs
678
+ for i in range(trt_network.num_layers):
679
+ layer = Layer.from_trt(graph, trt_network.get_layer(i), i)
680
+ graph._layers[layer.name] = layer
681
+ for i in range(trt_network.num_outputs):
682
+ tensor_name = trt_network.get_output(i).name
683
+ output_tensor = graph._tensors[tensor_name]
684
+ output_tensor.graph_output_index = i
685
+ graph._outputs[tensor_name] = output_tensor
686
+
687
+ return graph
688
+
689
+ @staticmethod
690
+ def from_network(network: Network, builder_config):
691
+ builder_flags = builder_config.trt_builder_config.flags
692
+ with current_flags(builder_flags, network.strongly_typed):
693
+ graph = PipelineGraph.from_trt(network.trt_network)
694
+ graph.infer_shapes(network._generate_optimization_profiles()[-1])
695
+ return graph
696
+
697
+ def assign_shapes(self, shape_info=None, is_partial=False):
698
+ if shape_info is None:
699
+ for tensor in self.tensors:
700
+ tensor.shape = tensor.raw_shape
701
+ return
702
+ for tensor in self.tensors:
703
+ if tensor.name in shape_info.shapes:
704
+ tensor.shape = shape_info.shapes[tensor.name]
705
+ elif not is_partial:
706
+ raise ValueError(f"Cannot find shape for tensor: {tensor.name}")
707
+ if shape_info.max_shapes is not None:
708
+ if tensor.name in shape_info.max_shapes:
709
+ tensor.max_shape = shape_info.max_shapes[tensor.name]
710
+ elif not is_partial:
711
+ raise ValueError(
712
+ f"Cannot find max shape for tensor: {tensor.name}")
713
+ if tensor.name in shape_info.values:
714
+ tensor.value = shape_info.values[tensor.name]
715
+ for layer in self.layers:
716
+ if layer.name in shape_info.shape_layers:
717
+ layer._is_shape_io = True
718
+
719
+ def infer_shapes(self, profile=None):
720
+ from .shape_info import get_shape_info
721
+
722
+ shape_info = get_shape_info(self._trt, profile)
723
+ self.assign_shapes(shape_info)
724
+
725
+ def as_trt(self) -> trt.INetworkDefinition:
726
+ return self._trt
727
+
728
+ def get_input(self, name: str) -> Tensor:
729
+ return self._inputs.get(name)
730
+
731
+ def is_input(self, name: str) -> bool:
732
+ return name in self._inputs
733
+
734
+ @property
735
+ def inputs(self) -> List[Tensor]:
736
+ return [*self._inputs.values()]
737
+
738
+ @property
739
+ def num_inputs(self) -> int:
740
+ return self._trt.num_inputs
741
+
742
+ def get_output(self, name: str) -> Tensor:
743
+ return self._outputs.get(name)
744
+
745
+ def is_output(self, name: str) -> bool:
746
+ return name in self._outputs
747
+
748
+ @property
749
+ def outputs(self) -> List[Tensor]:
750
+ return [*self._outputs.values()]
751
+
752
+ @property
753
+ def num_outputs(self) -> int:
754
+ return self._trt.num_outputs
755
+
756
+ def get_tensor(self, name: str) -> Tensor:
757
+ return self._tensors.get(name)
758
+
759
+ @property
760
+ def tensors(self) -> List[Tensor]:
761
+ return [*self._tensors.values()]
762
+
763
+ def get_layer(self, name: str) -> Layer:
764
+ return self._layers.get(name)
765
+
766
+ @property
767
+ def layers(self) -> List[Layer]:
768
+ return [*self._layers.values()]
769
+
770
+ @property
771
+ def sorted_layers(self) -> List[Layer]:
772
+ sorted_layer_ids = get_sorted_layer_ids(self.as_trt())
773
+ return [
774
+ self.get_layer(self.as_trt().get_layer(layer_id).name)
775
+ for layer_id in sorted_layer_ids
776
+ ]
777
+
778
+ @property
779
+ def num_layers(self) -> int:
780
+ return self._trt.num_layers
781
+
782
+ def to_dot(self,
783
+ path=None,
784
+ per_device=False,
785
+ per_block=False,
786
+ ignore_shape_io=False,
787
+ no_style=False,
788
+ extra_attrs=None) -> Optional[str]:
789
+ '''
790
+ Get a graphviz representation of the graph.
791
+
792
+ Parameters:
793
+ path: the path to save the graphviz source to; if not provided, the graphviz source code is returned instead of being saved
794
+ '''
795
+ try:
796
+ import graphviz
797
+ except ImportError:
798
+ logger.error(
799
+ "Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()"
800
+ )
801
+ return
802
+
803
+ extra_attrs = extra_attrs or []
804
+
805
+ graph = graphviz.Digraph()
806
+ input_block_graph = graphviz.Digraph(name='cluster_inputs')
807
+ output_block_graph = graphviz.Digraph(name='cluster_outputs')
808
+ device_graphs = {}
809
+ block_graphs = {}
810
+ block_graph_mapping = []
811
+ tensor_names = set()
812
+ layer_names = set()
813
+
814
+ common_style = dict(fontname='Arial', )
815
+ node_style = dict(
816
+ **common_style,
817
+ style='rounded,filled,bold',
818
+ )
819
+ tensor_style = dict(
820
+ **node_style,
821
+ shape='ellipse',
822
+ fillcolor='white',
823
+ )
824
+ input_tensor_style = {**tensor_style, 'fillcolor': 'green'}
825
+ output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'}
826
+ layer_style = dict(
827
+ **node_style,
828
+ shape='box',
829
+ fillcolor='white',
830
+ )
831
+ shape_layer_style = {**layer_style, 'fillcolor': 'grey'}
832
+ helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'}
833
+ graph_style = dict(
834
+ **common_style,
835
+ style='rounded',
836
+ penwidth='5',
837
+ fontsize='28',
838
+ )
839
+ device_graph_style = dict(
840
+ **graph_style,
841
+ color='cornflowerblue',
842
+ )
843
+ block_graph_style = dict(
844
+ **graph_style,
845
+ color='darkcyan',
846
+ )
847
+ input_block_style = dict(
848
+ **graph_style,
849
+ color='green',
850
+ )
851
+ output_block_style = dict(
852
+ **graph_style,
853
+ color='lightgreen',
854
+ )
855
+ if no_style:
856
+ device_graph_style = {}
857
+ block_graph_style = {}
858
+ input_block_style = {}
859
+ output_block_style = {}
860
+ input_block_graph.attr(label='inputs', **input_block_style)
861
+ output_block_graph.attr(label='outputs', **output_block_style)
862
+
863
+ def get_tensor_labels(tensor):
864
+ labels = []
865
+ if tensor.value is not None:
866
+ labels.append(f"value={tensor.value}")
867
+ else:
868
+ labels.append(f"dtype={tensor.dtype.name}{tensor.shape}")
869
+ for attr_name in extra_attrs:
870
+ if attr_name in tensor.attrs:
871
+ labels.append(f"{attr_name}={tensor.attrs[attr_name]}")
872
+ return labels
873
+
874
+ def get_device_graph(name):
875
+ if per_device and name.startswith('device'):
876
+ device_name = name.split('_')[0]
877
+ if device_name not in device_graphs:
878
+ device_graph = graphviz.Digraph(name='cluster_' +
879
+ device_name)
880
+ device_graph.attr(label=device_name, **device_graph_style)
881
+ device_graphs[device_name] = device_graph
882
+ return device_graphs[device_name]
883
+ return None
884
+
885
+ def get_block_graph(layer, current_graph):
886
+ if per_block and 'block_id' in layer.attrs:
887
+ block_label = f"block{layer.attrs['block_id']}"
888
+ if current_graph.name is not None:
889
+ graph_label = current_graph.name[len('cluster_'):]
890
+ else:
891
+ graph_label = ''
892
+ block_name = f"{graph_label}{block_label}"
893
+ if block_name not in block_graphs:
894
+ block_graph = graphviz.Digraph(name='cluster_' + block_name)
895
+ block_graph.attr(label=block_label, **block_graph_style)
896
+ block_graphs[block_name] = block_graph
897
+ block_graph_mapping.append((current_graph, block_graph))
898
+ return block_graphs[block_name]
899
+ return current_graph
900
+
901
+ for name, tensor in self._tensors.items():
902
+ style = tensor_style
903
+ if tensor.is_graph_input:
904
+ style = input_tensor_style
905
+ current_graph = input_block_graph
906
+ elif tensor.is_graph_output:
907
+ style = output_tensor_style
908
+ current_graph = output_block_graph
909
+ elif tensor.producer.num_outputs == 1:
910
+ continue
911
+ else:
912
+ current_graph = get_device_graph(name) or graph
913
+ current_graph = get_block_graph(tensor.producer, current_graph)
914
+ if no_style:
915
+ style = {}
916
+ labels = [name, *get_tensor_labels(tensor)]
917
+ content = "\n".join(labels)
918
+ current_graph.node(name, content, **style)
919
+ tensor_names.add(name)
920
+
921
+ for layer in self.sorted_layers:
922
+ name = layer.name
923
+
924
+ style = layer_style
925
+ if layer.is_shape_io:
926
+ if ignore_shape_io:
927
+ continue
928
+ style = shape_layer_style
929
+ elif layer.attrs.get("role", None) == "helper":
930
+ style = helper_layer_style
931
+ fillcolor = None
932
+ plugin_type = None
933
+ if layer.type == trt.LayerType.PLUGIN_V2:
934
+ fillcolor = 'yellow'
935
+ layer.to_subclass()
936
+ plugin_type = layer.as_trt().plugin.plugin_type
937
+ layer.to_base_class()
938
+ if layer.type == trt.LayerType.MATRIX_MULTIPLY or plugin_type == 'Gemm':
939
+ fillcolor = 'orange'
940
+ if fillcolor is not None:
941
+ style = {**style, 'fillcolor': fillcolor}
942
+ if no_style:
943
+ style = {}
944
+
945
+ layer_attrs = {}
946
+ layer_type = layer.type
947
+ layer.to_subclass()
948
+ if layer_type == trt.LayerType.CONSTANT:
949
+ if not layer.is_shape_io:
950
+ if trt.volume(layer.get_output(0).shape) <= 8:
951
+ weights = layer.as_trt().weights
952
+ if isinstance(weights, trt.Weights):
953
+ weights = weights.numpy()
954
+ value = np.array2string(
955
+ weights,
956
+ formatter={'float_kind': lambda x: f"{x:.2e}"})
957
+ layer_attrs['value'] = value
958
+ elif layer_type == trt.LayerType.SHUFFLE:
959
+ for attr_name in ['first_transpose', 'second_transpose']:
960
+ attr_value = getattr(layer.as_trt(), attr_name)
961
+ if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7):
962
+ tensor = layer.get_input(
963
+ 0
964
+ ) if attr_name == 'first_transpose' else layer.get_output(
965
+ 0)
966
+ layer_attrs[attr_name] = tuple(
967
+ attr_value)[:len(tensor.shape)]
968
+ if layer.num_inputs < 2:
969
+ attr_value = layer.as_trt().reshape_dims
970
+ layer_attrs['reshape_dims'] = attr_value
971
+ elif layer_type == trt.LayerType.SLICE:
972
+ if layer.num_inputs < 2 or layer.get_input(1) is None:
973
+ layer_attrs['start'] = layer.as_trt().start
974
+ if layer.num_inputs < 4 or layer.get_input(3) is None:
975
+ attr_value = layer.as_trt().stride
976
+ if attr_value != tuple(
977
+ [1] * len(layer.get_output(0).shape)):
978
+ layer_attrs['stride'] = attr_value
979
+ layer.to_base_class()
980
+
981
+ if layer.is_shape_io:
982
+ labels = [layer.type.name]
983
+ else:
984
+ labels = [name, layer.type.name]
985
+ for key, value in layer_attrs.items():
986
+ labels.append(f"{key}={value}")
987
+ for attr_name in extra_attrs:
988
+ if attr_name in layer.attrs:
989
+ labels.append(f"{attr_name}={layer.attrs[attr_name]}")
990
+ if layer.num_outputs == 1:
991
+ output = layer.get_output(0)
992
+ if output.name != f'{layer.name}_output_0':
993
+ labels.append(f"output={output.name}")
994
+ labels.extend(get_tensor_labels(output))
995
+ content = "\n".join(labels)
996
+
997
+ current_graph = get_device_graph(name) or graph
998
+ current_graph = get_block_graph(layer, current_graph)
999
+ current_graph.node(name, content, **style)
1000
+ layer_names.add(name)
1001
+
1002
+ for index, input in enumerate(layer.inputs):
1003
+ if input is not None:
1004
+ if input.is_graph_input or input.producer.num_outputs > 1:
1005
+ if input.name in tensor_names:
1006
+ graph.edge(input.name, name, str(index))
1007
+ else:
1008
+ if input.producer.name in layer_names:
1009
+ graph.edge(input.producer.name, name, str(index))
1010
+ if layer.num_outputs > 1 or (layer.num_outputs == 1 and
1011
+ layer.get_output(0).is_graph_output):
1012
+ for index, output in enumerate(layer.outputs):
1013
+ graph.edge(name, output.name, str(index))
1014
+
1015
+ graph.subgraph(input_block_graph)
1016
+ graph.subgraph(output_block_graph)
1017
+ for parent_graph, block_graph in block_graph_mapping:
1018
+ parent_graph.subgraph(block_graph)
1019
+ for device_graph in device_graphs.values():
1020
+ graph.subgraph(device_graph)
1021
+
1022
+ if not path:
1023
+ return graph.source
1024
+ graph.save(path)
1025
+
1026
+ @staticmethod
1027
+ def trt_to_dot(trt_network, path=None):
1028
+ graph = PipelineGraph.from_trt(trt_network)
1029
+ graph.assign_shapes()
1030
+ dot = graph.to_dot(no_style=True)
1031
+ if path is not None:
1032
+ with open(path, "w") as f:
1033
+ f.write(dot)
1034
+ else:
1035
+ return dot
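
The file above defines the graph wrapper used throughout the auto-parallel passes: Tensor and Layer mirror their TensorRT counterparts, GraphRunner executes a built engine, and PipelineGraph supports cloning networks layer by layer, shape annotation, and graphviz export. The snippet below is a minimal usage sketch and is not part of the uploaded sources: it hand-builds a trivial TensorRT network (the tensor names and shapes are illustrative assumptions), wraps it, and prints the unstyled graphviz source.

import tensorrt as trt

from tensorrt_llm.auto_parallel.pipeline_graph import PipelineGraph

# Build a toy static-shape network; the names and shapes here are
# assumptions made purely for illustration.
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
flags = 0
if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__:
    flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)
x = network.add_input("x", trt.float32, (2, 8))
y = network.add_input("y", trt.float32, (2, 8))
add = network.add_elementwise(x, y, trt.ElementWiseOperation.SUM)
network.mark_output(add.get_output(0))

# Wrap the raw network, copy the static shapes onto the wrapper, and
# render a plain graphviz description (requires the graphviz package).
graph = PipelineGraph.from_trt(network)
graph.assign_shapes()            # no ShapeInfo given: falls back to raw_shape
print(graph.to_dot(no_style=True))
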
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/runtime_profiling.py ADDED
@@ -0,0 +1,150 @@
1
+ import numpy as np
2
+ import tensorrt as trt
3
+ import torch
4
+
5
+ from tensorrt_llm.logger import logger
6
+ from tensorrt_llm.network import get_plugin_info
7
+
8
+ from .shape_info import get_per_layer_graph
9
+ from .utils import get_cache_key, get_trt_network, get_updated_plugin
10
+
11
+
12
+ class NvtxProfiler(object):
13
+
14
+ def __init__(self, nvtx_name, enable=True):
15
+ self.nvtx_name = nvtx_name
16
+ self.enable = enable
17
+
18
+ def __enter__(self):
19
+ if self.enable:
20
+ torch.cuda.nvtx.range_push(self.nvtx_name)
21
+
22
+ def __exit__(self, exc_type, exc_val, exc_tb):
23
+ if self.enable:
24
+ torch.cuda.nvtx.range_pop()
25
+
26
+
27
+ class LayerProfiler(trt.IProfiler):
28
+
29
+ def __init__(self):
30
+ trt.IProfiler.__init__(self)
31
+ self.layer_count = 0
32
+ self.time = 0
33
+
34
+ def report_layer_time(self, layer_name, ms):
35
+ logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
36
+ self.time += ms
37
+ self.layer_count += 1
38
+
39
+
40
+ class RuntimeProfiler(object):
41
+
42
+ def __init__(self):
43
+ self.timing_cache = None
44
+
45
+ def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
46
+ is_plugin = layer.type == trt.LayerType.PLUGIN_V2
47
+ if is_plugin and len(layer_attrs) > 0:
48
+ plugin_info = get_plugin_info(
49
+ get_trt_network(layer),
50
+ layer.name,
51
+ )
52
+ new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
53
+ layer_attrs = {"plugin": new_plugin}
54
+ graph, output_mapping = get_per_layer_graph(layer, shapes, values,
55
+ layer_attrs)
56
+ graph._io_buffer_mapping = io_buffer_mapping
57
+ network = graph.as_trt()
58
+ if network.num_outputs > 0 and np.all([
59
+ network.get_output(i).is_shape_tensor
60
+ for i in range(network.num_outputs)
61
+ ]):
62
+ return 0.0
63
+ for proxy_output, output in output_mapping.items():
64
+ shapes[proxy_output] = shapes[output]
65
+ if not self.timing_cache:
66
+ self.timing_cache = network.builder.create_builder_config(
67
+ ).create_timing_cache(b"")
68
+ runner = graph.get_runner(
69
+ shapes,
70
+ values,
71
+ timing_cache=self.timing_cache,
72
+ )
73
+ context = runner.session.context
74
+ context.profiler = LayerProfiler()
75
+ runner.run()
76
+ profiler_time_first_run = context.profiler.time
77
+ runner.run()
78
+ return (context.profiler.time - profiler_time_first_run) * 1000.0
79
+
80
+ def runtime_profile(self, layer, layer_attrs, input_values, strategy,
81
+ device_mesh):
82
+ logger.debug(f"start to profile layer {layer.name}")
83
+ shapes = {}
84
+ values = {}
85
+ dtypes = {}
86
+ trt_layer = layer.as_trt()
87
+
88
+ sharding_sequences = ()
89
+ for i in range(layer.num_inputs):
90
+ input = trt_layer.get_input(i)
91
+ if input is not None:
92
+ shapes[input.name] = strategy.sharding_specs[
93
+ f'input{i}'].get_sharded_shape_per_device()
94
+ dtypes[input.name] = input.dtype
95
+ sharding_sequences += (str(
96
+ strategy.sharding_specs[f"input{i}"].sharding_sequence), )
97
+ if i in input_values:
98
+ values[input.name] = input_values[i]
99
+ else:
100
+ value = layer.get_input(i).value
101
+ if value is not None:
102
+ values[input.name] = value
103
+ else:
104
+ sharding_sequences += (None, )
105
+
106
+ for i in range(layer.num_outputs):
107
+ output = trt_layer.get_output(i)
108
+ if f'output{i}' in strategy.communication_actions:
109
+ shapes[output.name] = strategy.communication_actions[
110
+ f'output{i}'].sharding_spec.get_sharded_shape_per_device()
111
+ else:
112
+ shapes[output.name] = strategy.sharding_specs[
113
+ f'output{i}'].get_sharded_shape_per_device()
114
+ dtypes[output.name] = output.dtype
115
+ sharding_sequences += (str(
116
+ strategy.sharding_specs[f"output{i}"].sharding_sequence), )
117
+ data_key = get_cache_key(
118
+ trt_layer,
119
+ shapes,
120
+ values,
121
+ dtypes=dtypes,
122
+ updated_attrs=layer_attrs,
123
+ )
124
+ data_key += (sharding_sequences, )
125
+ elapsed_time = device_mesh.prof_database.query(
126
+ device_mesh.cluster_key,
127
+ data_key,
128
+ )
129
+ if elapsed_time:
130
+ logger.debug(
131
+ f'runtime profiling cache hit {data_key}: {elapsed_time} us')
132
+ return elapsed_time
133
+ with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
134
+ elapsed_time = self._profile(
135
+ layer.as_trt(),
136
+ layer_attrs,
137
+ shapes,
138
+ values,
139
+ layer.graph._io_buffer_mapping,
140
+ )
141
+ logger.debug(
142
+ f'runtime profiling cache miss {data_key}: {elapsed_time} us')
143
+
144
+ device_mesh.prof_database.update(
145
+ device_mesh.cluster_key,
146
+ data_key,
147
+ (elapsed_time, strategy.alpha_beta_cost),
148
+ )
149
+
150
+ return elapsed_time
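
RuntimeProfiler measures a single layer by rebuilding it as a one-layer engine (via get_per_layer_graph), attaching a LayerProfiler, and running the engine twice: the first run absorbs allocation and lazy-initialization costs, and only the second run's accumulated per-layer time is reported, converted from milliseconds to microseconds. The helper below is a minimal sketch of that same two-run pattern applied to an arbitrary PipelineGraph; the function name is hypothetical, and a buildable graph plus a visible CUDA device are assumed.

from tensorrt_llm.auto_parallel.runtime_profiling import LayerProfiler, NvtxProfiler


def time_graph_steady_state(graph, shapes):
    """Return the steady-state per-layer GPU time of `graph` in microseconds.

    Mirrors RuntimeProfiler._profile: run once to warm up, run again, and
    report only the second run's accumulated layer time.
    """
    runner = graph.get_runner(shapes)       # builds an engine and a Session
    context = runner.session.context
    context.profiler = LayerProfiler()      # the trt.IProfiler defined above
    with NvtxProfiler("warmup_run"):
        runner.run()
    warmup_ms = context.profiler.time
    with NvtxProfiler("measured_run"):
        runner.run()
    return (context.profiler.time - warmup_ms) * 1000.0  # ms -> us
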
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/shape_info.py ADDED
@@ -0,0 +1,362 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Dict, List, Set
4
+
5
+ import numpy as np
6
+ import tensorrt as trt
7
+ import torch
8
+
9
+ from tensorrt_llm._common import _is_building
10
+ from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str,
11
+ trt_dtype_to_torch)
12
+ from tensorrt_llm.logger import logger
13
+
14
+ from .pipeline_graph import PipelineGraph
15
+ from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids,
16
+ set_trt_network, to_base_class_layer, to_subclass_layer,
17
+ to_trt_weights)
18
+
19
+
20
+ class ShapeType(Enum):
21
+ MIN = 0
22
+ OPT = 1
23
+ MAX = 2
24
+
25
+
26
+ _trt_to_type_dict = {
27
+ trt.int64: int,
28
+ trt.bool: bool,
29
+ }
30
+
31
+
32
+ def get_shape_layers(trt_network):
33
+ shape_layers = set()
34
+ for i in range(trt_network.num_layers):
35
+ layer = trt_network.get_layer(i)
36
+ if (layer.num_inputs > 0 and np.all([
37
+ layer.get_input(j).is_shape_tensor
38
+ for j in range(layer.num_inputs)
39
+ if layer.get_input(j) is not None
40
+ ])) or (layer.num_outputs > 0 and np.all([
41
+ layer.get_output(j).is_shape_tensor
42
+ for j in range(layer.num_outputs)
43
+ ])):
44
+ shape_layers.add(layer.name)
45
+ return shape_layers
46
+
47
+
48
+ def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids):
49
+ layers = set()
50
+ shape_tensors = set()
51
+ for layer_id in reversed(sorted_layer_ids):
52
+ layer = trt_network.get_layer(layer_id)
53
+ in_shape_network = False
54
+ if layer.name in shape_layers:
55
+ in_shape_network = True
56
+ else:
57
+ for j in range(layer.num_outputs):
58
+ output = layer.get_output(j)
59
+ if output.name in shape_tensors:
60
+ in_shape_network = True
61
+ break
62
+ if in_shape_network:
63
+ layers.add(layer.name)
64
+ for j in range(layer.num_inputs):
65
+ input = layer.get_input(j)
66
+ if input is not None:
67
+ shape_tensors.add(input.name)
68
+ return layers
69
+
70
+
71
+ def get_shape_network(trt_network,
72
+ shapes,
73
+ values,
74
+ sorted_layer_ids,
75
+ profile=None,
76
+ shape_type: ShapeType = ShapeType.OPT):
77
+ shape_layers = get_shape_layers(trt_network)
78
+ layers_in_shape_network = get_layers_in_shape_network(
79
+ trt_network, shape_layers, sorted_layer_ids)
80
+ shape_graph = PipelineGraph.create_graph()
81
+ shape_network = shape_graph.as_trt()
82
+ shape_builder = shape_network.builder
83
+ shape_profile = shape_builder.create_optimization_profile()
84
+ for i in range(trt_network.num_inputs):
85
+ input = trt_network.get_input(i)
86
+ shapes[input.name] = input.shape
87
+ new_input = shape_graph.add_input(input)
88
+ if profile is not None:
89
+ if -1 in input.shape:
90
+ shape = profile.get_shape(input.name)
91
+ shape = shape[shape_type.value]
92
+ shapes[input.name] = shape
93
+ new_input.raw_shape = shape
94
+ if input.is_shape_tensor:
95
+ shape_values = profile.get_shape_input(input.name)
96
+ value = shape_values[shape_type.value]
97
+ values[input.name] = value
98
+ shape_profile.set_shape_input(input.name, value, value, value)
99
+ output_mapping = {}
100
+ for layer_id in sorted_layer_ids:
101
+ layer = trt_network.get_layer(layer_id)
102
+ if layer.name in shape_layers:
103
+ new_layer = shape_graph.add_layer(layer)
104
+ for i in range(layer.num_outputs):
105
+ output = layer.get_output(i)
106
+ if output.dtype == trt.DataType.BOOL:
107
+ proxy_layer = shape_network.add_cast(
108
+ new_layer.as_trt().get_output(i),
109
+ trt.DataType.INT32,
110
+ )
111
+ proxy_output = proxy_layer.get_output(0)
112
+ shape_graph.register_layer(proxy_layer)
113
+ shape_graph.add_output_shape(proxy_output)
114
+ output_mapping[proxy_output.name] = (output.name,
115
+ output.dtype)
116
+ else:
117
+ shape_graph.add_output_shape(output)
118
+ elif layer.name in layers_in_shape_network:
119
+ if layer.type == trt.LayerType.CONSTANT:
120
+ shape_graph.add_input(layer.get_output(0))
121
+ else:
122
+ shape_graph.add_layer(layer)
123
+ return shape_network, shape_profile, shape_layers, output_mapping
124
+
125
+
126
+ def get_per_layer_graph(
127
+ layer,
128
+ shapes,
129
+ values,
130
+ updated_attrs=None,
131
+ is_shape_io: bool = None,
132
+ ):
133
+ graph = PipelineGraph.create_graph()
134
+ network = graph.as_trt()
135
+ is_shape_layer = layer.num_inputs != 0
136
+ for i in range(layer.num_inputs):
137
+ input = layer.get_input(i)
138
+ if input is not None:
139
+ shape = shapes[input.name]
140
+ if (values.get(input.name) is not None
141
+ and not isinstance(values[input.name], torch.Tensor)):
142
+ value = values[input.name]
143
+ weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype))
144
+ weights = to_trt_weights(weights)
145
+ input_layer = network.add_constant(shape, weights)
146
+ new_input = input_layer.get_output(0)
147
+ new_input.name = input.name
148
+ graph.register_layer(input_layer)
149
+ elif graph.get_input(input.name) is None:
150
+ new_input = graph.add_input(input)
151
+ new_input.raw_shape = shapes[input.name]
152
+ is_shape_layer = False
153
+ new_layer = graph.add_layer(
154
+ layer,
155
+ updated_attrs=updated_attrs,
156
+ )
157
+ output_mapping = {}
158
+ if layer.type == trt.LayerType.SHAPE:
159
+ is_shape_layer = True
160
+ if layer.num_inputs == 0:
161
+ is_shape_layer = False
162
+ if is_shape_io is not None:
163
+ is_shape_layer = is_shape_io
164
+ for i in range(layer.num_outputs):
165
+ output = layer.get_output(i)
166
+ value = values.get(output.name)
167
+ if value is not None and isinstance(value, torch.Tensor):
168
+ is_output_shape = False
169
+ elif is_shape_layer:
170
+ is_output_shape = True
171
+ else:
172
+ is_output_shape = False
173
+ if is_output_shape:
174
+ if output.dtype == trt.DataType.BOOL:
175
+ proxy_layer = network.add_cast(
176
+ new_layer.as_trt().get_output(i),
177
+ trt.DataType.INT32,
178
+ )
179
+ proxy_output = proxy_layer.get_output(0)
180
+ graph.register_layer(proxy_layer)
181
+ output_mapping[proxy_output.name] = (output.name, output.dtype)
182
+ output = proxy_output
183
+ graph.add_output_shape(output)
184
+ else:
185
+ graph.add_output(output)
186
+ return graph, output_mapping
187
+
188
+
189
+ @_is_building
190
+ def infer_shapes(network, shapes, values, profile=None):
191
+ if network.num_outputs == 0:
192
+ return
193
+ builder = network.builder
194
+ config = builder.create_builder_config()
195
+ config.builder_optimization_level = 0
196
+ config.flags = get_builder_flags()
197
+ profile = profile or builder.create_optimization_profile()
198
+ config.add_optimization_profile(profile)
199
+ plan = builder.build_serialized_network(network, config)
200
+ if plan is None:
201
+ raise RuntimeError(
202
+ 'Engine building failed when inferring shapes, please check the error log.'
203
+ )
204
+ runtime = trt.Runtime(logger.trt_logger)
205
+ engine = runtime.deserialize_cuda_engine(plan)
206
+ context = engine.create_execution_context()
207
+ for i in range(network.num_inputs):
208
+ input = network.get_input(i)
209
+ if input.is_shape_tensor:
210
+ value = values[input.name]
211
+ context.set_shape_input(engine[input.name], value)
212
+ for i in range(network.num_outputs):
213
+ output = network.get_output(i)
214
+ shape = context.get_tensor_shape(output.name)
215
+ shapes[output.name] = shape
216
+ if output.is_shape_tensor:
217
+ if shape == [0]:
218
+ values[output.name] = []
219
+ else:
220
+ if shape == []:
221
+ shape = [1]
222
+ value = torch.empty(
223
+ list(shape),
224
+ dtype=trt_dtype_to_torch(output.dtype),
225
+ device="cpu",
226
+ )
227
+ values[output.name] = value
228
+ context.set_tensor_address(output.name, value.data_ptr())
229
+ context.infer_shapes()
230
+ assert context.all_binding_shapes_specified
231
+ for i in range(network.num_outputs):
232
+ output = network.get_output(i)
233
+ if isinstance(values.get(output.name), torch.Tensor):
234
+ values[output.name] = values[output.name].tolist()
235
+
236
+
237
+ @dataclass
238
+ class ShapeInfo:
239
+ shapes: Dict[str, trt.Dims]
240
+ values: Dict[str, List[int]]
241
+ shape_layers: Set[str]
242
+ max_shapes: Dict[str, trt.Dims] = None
243
+
244
+
245
+ def set_constant_value(layer, values):
246
+ to_subclass_layer(layer)
247
+ output_name = layer.get_output(0).name
248
+ weights = layer.weights
249
+ if isinstance(weights, trt.Weights):
250
+ weights = weights.numpy()
251
+ values[output_name] = list(weights)
252
+ to_base_class_layer(layer)
253
+
254
+
255
+ def infer_per_layer_shapes(
256
+ layer: trt.ILayer,
257
+ shapes,
258
+ values,
259
+ cache=None,
260
+ is_shape_io=False,
261
+ ):
262
+ if layer.type == trt.LayerType.CONSTANT:
263
+ to_subclass_layer(layer)
264
+ output_name = layer.get_output(0).name
265
+ shape = layer.shape
266
+ shapes[output_name] = shape
267
+ if is_shape_io:
268
+ set_constant_value(layer, values)
269
+ to_base_class_layer(layer)
270
+ return
271
+ elif layer.type == trt.LayerType.SHAPE:
272
+ input_name = layer.get_input(0).name
273
+ output_name = layer.get_output(0).name
274
+ shape = [*shapes[input_name]]
275
+ shapes[output_name] = trt.Dims([len(shape)])
276
+ values[output_name] = shape
277
+ return
278
+ if cache is not None:
279
+ cache_key = get_cache_key(layer, shapes, values)
280
+ if cache_key in cache:
281
+ output_shapes, output_values = cache[cache_key]
282
+ for i in range(layer.num_outputs):
283
+ output = layer.get_output(i)
284
+ shapes[output.name] = output_shapes[i]
285
+ if output_values[i] is not None:
286
+ values[output.name] = output_values[i]
287
+ return
288
+ graph, output_mapping = get_per_layer_graph(layer, shapes, values)
289
+ dtypes = [
290
+ trt_dtype_to_str(layer.get_input(i).dtype)
291
+ for i in range(layer.num_inputs)
292
+ ]
293
+ layer_info = (f"type={cache_key[0]}, "
294
+ f"attrs={dict(cache_key[1])}, "
295
+ f"dtypes={dtypes}, "
296
+ f"shapes={list(cache_key[2])}, "
297
+ f"values={list(cache_key[3])}")
298
+ logger.debug(f"infer shapes for layer {layer.name} ({layer_info})")
299
+ try:
300
+ infer_shapes(graph.as_trt(), shapes, values)
301
+ except RuntimeError as e:
302
+ raise RuntimeError(
303
+ f"infer shapes failed for layer {layer.name} ({layer_info})") from e
304
+ for proxy_output, (output, dtype) in output_mapping.items():
305
+ shapes[output] = shapes[proxy_output]
306
+ del shapes[proxy_output]
307
+ if proxy_output in values:
308
+ values[output] = [
309
+ *map(_trt_to_type_dict[dtype], values[proxy_output])
310
+ ]
311
+ del values[proxy_output]
312
+ if cache is not None:
313
+ logger.debug(
314
+ f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}"
315
+ )
316
+ output_shapes = []
317
+ output_values = []
318
+ for i in range(layer.num_outputs):
319
+ output = layer.get_output(i)
320
+ output_shapes.append(shapes[output.name])
321
+ output_values.append(values.get(output.name))
322
+ cache[cache_key] = (output_shapes, output_values)
323
+
324
+
325
+ def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT):
326
+ shapes = {}
327
+ values = {}
328
+ sorted_layer_ids = get_sorted_layer_ids(trt_network)
329
+ infer_shape_layers = False
330
+
331
+ shape_network, shape_profile, shape_layers, output_mapping = get_shape_network(
332
+ trt_network,
333
+ shapes,
334
+ values,
335
+ sorted_layer_ids,
336
+ profile=profile,
337
+ shape_type=shape_type)
338
+ try:
339
+ infer_shapes(shape_network, shapes, values, shape_profile)
340
+ for proxy_output, (output, dtype) in output_mapping.items():
341
+ shapes[output] = shapes[proxy_output]
342
+ values[output] = [
343
+ *map(_trt_to_type_dict[dtype], values[proxy_output])
344
+ ]
345
+ del shapes[proxy_output]
346
+ del values[proxy_output]
347
+ except RuntimeError:
348
+ infer_shape_layers = True
349
+
350
+ cache = {}
351
+ for layer_id in sorted_layer_ids:
352
+ layer = trt_network.get_layer(layer_id)
353
+ is_shape_io = layer.name in shape_layers
354
+ if is_shape_io and not infer_shape_layers:
355
+ continue
356
+ set_trt_network(layer, trt_network)
357
+ infer_per_layer_shapes(layer,
358
+ shapes,
359
+ values,
360
+ cache,
361
+ is_shape_io=is_shape_io)
362
+ return ShapeInfo(shapes, values, shape_layers)
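
get_shape_info is the entry point of shape propagation: it extracts the shape-only sub-network, tries to evaluate it in a single engine build, and falls back to cached per-layer inference when that fails. The snippet below is a minimal sketch of calling it on a small dynamic-shape network; the network, tensor names, and profile ranges are illustrative assumptions, and building the temporary engines requires a visible GPU.

import tensorrt as trt

from tensorrt_llm.auto_parallel.shape_info import ShapeType, get_shape_info

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
flags = 0
if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__:
    flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)

# Toy network with a dynamic batch dimension (illustrative shapes).
x = network.add_input("x", trt.float32, (-1, 16))
relu = network.add_activation(x, trt.ActivationType.RELU)
network.mark_output(relu.get_output(0))

# The optimization profile supplies min/opt/max for the dynamic axis.
profile = builder.create_optimization_profile()
profile.set_shape("x", (1, 16), (4, 16), (8, 16))

shape_info = get_shape_info(network, profile, shape_type=ShapeType.OPT)
print(shape_info.shapes["x"])       # shapes resolved at the OPT point
print(shape_info.shape_layers)      # layers classified as shape-only
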
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/simplifier.py ADDED
@@ -0,0 +1,837 @@
1
+ import math
2
+ import re
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Dict, List, Tuple
6
+
7
+ import numpy as np
8
+
9
+ from tensorrt_llm.network import Network
10
+
11
+ from .config import AutoParallelConfig
12
+ from .device_mesh import PhysicalDeviceMesh
13
+ from .pipeline_graph import PipelineGraph
14
+ from .shape_info import ShapeInfo, ShapeType, get_shape_info
15
+ from .tensor_parallel.p2p_node import P2PType
16
+ from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger
17
+
18
+
19
+ class StageType(Enum):
20
+ START = 0
21
+ BLOCK = 1
22
+ END = 2
23
+
24
+
25
+ class BuildingBlock:
26
+
27
+ def __init__(self, graph, layer_range) -> None:
28
+ self.graph = graph
29
+ self.layer_range = layer_range
30
+ self.network = graph.as_trt()
31
+ self.owned_inputs = {}
32
+ self.is_edges_collected = False
33
+ self.intra_edges = []
34
+ self.src_inter_edges = []
35
+ self.dst_inter_edges = []
36
+ self.relative_src_inter_edges = []
37
+ self.relative_dst_inter_edges = []
38
+ self.relative_inter_edges = set()
39
+ self.edge_hash = None
40
+ self.outputs = None
41
+ self.type_id = -1
42
+ self.block_id = -1
43
+ self.p2p_type = None
44
+ self.is_superset = False
45
+ self.is_subset = False
46
+ self.sorted_layer_ids = []
47
+
48
+ def collect_edges(self):
49
+ if self.is_edges_collected:
50
+ return
51
+ for layer_index in self.layer_range:
52
+ trt_layer = self.network.get_layer(layer_index)
53
+ layer = self.graph.get_layer(trt_layer.name)
54
+ layer_offset = layer.index - self.layer_range.start
55
+ for input_index, input in enumerate(layer.inputs):
56
+ if input is not None:
57
+ if input.is_graph_input:
58
+ is_owned = input.graph_input_index in self.owned_inputs
59
+ if not is_owned and np.all([
60
+ layer.index in self.layer_range or np.all([
61
+ output.as_trt().is_shape_tensor
62
+ for output in layer.outputs
63
+ ]) for layer, _ in input.consumers
64
+ ]):
65
+ self.owned_inputs[input.graph_input_index] = len(
66
+ self.owned_inputs)
67
+ is_owned = True
68
+ if is_owned:
69
+ self.intra_edges.append(
70
+ (-1, self.owned_inputs[input.graph_input_index],
71
+ layer_offset, input_index))
72
+ else:
73
+ self.dst_inter_edges.append(
74
+ (-1, input.graph_input_index, layer_offset,
75
+ input_index))
76
+ else:
77
+ src_layer_index = input.producer.index
78
+ if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
79
+ self.dst_inter_edges.append(
80
+ (src_layer_index, input.output_index,
81
+ layer_offset, input_index))
82
+ else:
83
+ src_layer_offset = src_layer_index - self.layer_range.start
84
+ self.intra_edges.append(
85
+ (src_layer_offset, input.output_index,
86
+ layer_offset, input_index))
87
+ for output_index, output in enumerate(layer.outputs):
88
+ for dst_layer, dst_input_index in output.consumers:
89
+ dst_layer_index = dst_layer.index
90
+ if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
91
+ self.src_inter_edges.append(
92
+ (layer_offset, output_index, dst_layer_index,
93
+ dst_input_index))
94
+ self.edge_hash = tuple(self.intra_edges)
95
+ self.outputs = sorted(
96
+ set((edge[0], edge[1]) for edge in self.src_inter_edges))
97
+ self.is_edges_collected = True
98
+
99
+ def collect_relative_inter_edges(self, layer_to_block):
100
+ self.collect_edges()
101
+ for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
102
+ if src_layer_index in layer_to_block:
103
+ src_block = layer_to_block[src_layer_index]
104
+ src_layer_offset = src_layer_index - src_block.layer_range.start
105
+ dst = (self.type_id, dst_layer_index, dst_input_index)
106
+ self.relative_dst_inter_edges.append(
107
+ (src_block.type_id, src_layer_offset, src_output_index,
108
+ *dst))
109
+ else:
110
+ self.relative_dst_inter_edges.append(
111
+ (-1, src_layer_index, src_output_index, self.type_id,
112
+ dst_layer_index, dst_input_index))
113
+ self.relative_inter_edges = set(self.relative_dst_inter_edges +
114
+ self.outputs)
115
+
116
+ def get_input_names(self):
117
+ self.collect_edges()
118
+ input_tensor_names = []
119
+ for edge in self.dst_inter_edges:
120
+ layer_index = edge[0]
121
+ output_index = edge[1]
122
+ if layer_index == -1:
123
+ tensor_name = self.network.get_input(output_index).name
124
+ else:
125
+ tensor_name = self.network.get_layer(layer_index).get_output(
126
+ output_index).name
127
+ input_tensor_names.append(tensor_name)
128
+ return input_tensor_names
129
+
130
+ def get_input_mapping(self, last_blocks):
131
+ input_mapping = {}
132
+ for tensor_name, relative_edge in zip(self.get_input_names(),
133
+ self.relative_dst_inter_edges):
134
+ type_id = relative_edge[0]
135
+ output_index = relative_edge[2]
136
+ if type_id >= 0:
137
+ last_block = last_blocks[type_id]
138
+ layer_offset = relative_edge[1]
139
+ mapped_layer_index = last_block.layer_range.start + layer_offset
140
+ mapped_tensor_name = self.network.get_layer(
141
+ mapped_layer_index).get_output(output_index).name
142
+ input_mapping[tensor_name] = mapped_tensor_name
143
+ else:
144
+ input_mapping[tensor_name] = tensor_name
145
+ return input_mapping
146
+
147
+
148
+ @dataclass
149
+ class GraphMapping:
150
+ layer_mapping: Dict[int, int] = None
151
+ block_mapping: Dict[int, int] = None
152
+ p2p_types: Dict[int, P2PType] = None
153
+ p2p_tensors: Dict[int, List[str]] = None
154
+ block_to_stage: Dict[int, int] = None
155
+ same_spec_layer_mapping: Dict[str, str] = None
156
+
157
+
158
+ @dataclass
159
+ class GraphConfig:
160
+ num_micro_batches: int = 1
161
+ num_blocks: int = 1
162
+ num_stages: int = 1
163
+ has_cross_device: bool = False
164
+ has_cross_host: bool = False
165
+ graph_mapping: GraphMapping = None
166
+ phy_mesh: PhysicalDeviceMesh = None
167
+ stage_phy_meshes: List[PhysicalDeviceMesh] = None
168
+
169
+
170
+ class Simplifier:
171
+
172
+ def __init__(self, network: Network, config: AutoParallelConfig):
173
+ self.config = config
174
+ self.sharded_io_allowlist = config.sharded_io_allowlist
175
+ self.same_buffer_io = config.same_buffer_io
176
+ self.same_spec_io = config.same_spec_io.copy()
177
+ for key, value in self.same_buffer_io.items():
178
+ if key not in self.same_spec_io:
179
+ self.same_spec_io[key] = value
180
+
181
+ self.llm_network = network
182
+ self.network = network.trt_network
183
+ self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
184
+ self.graph = self.get_graph()
185
+ self.init_layer_hash()
186
+
187
+ module_tree = self.get_module_tree()
188
+ building_blocks = self.collect_building_blocks(module_tree)
189
+ blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
190
+ self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
191
+ blocks_by_module_hash)
192
+ self.layer_to_block = self.get_layer_to_block()
193
+ self.blocks = self.get_all_blocks()
194
+ self.backbone_blocks = self.get_backbone_blocks()
195
+ self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
196
+ self.graph_for_shape = self.create_simplified_graph_for_shape()
197
+ self.shape_info = None
198
+ self.num_micro_batches = None
199
+
200
+ def infer_shapes(self, num_micro_batches):
201
+ if self.num_micro_batches == num_micro_batches:
202
+ return
203
+ with silent_trt_logger():
204
+ self.shape_info = self.get_full_shape_info(num_micro_batches)
205
+ self.graph.assign_shapes(self.shape_info)
206
+ self.num_micro_batches = num_micro_batches
207
+
208
+ def list_all_num_micro_batches(self):
209
+ opt_batch_size = self.get_opt_batch_size()
210
+ candidates = []
211
+ for num_micro_batches in range(1, self.get_opt_batch_size() + 1):
212
+ if opt_batch_size % num_micro_batches == 0:
213
+ candidates.append(num_micro_batches)
214
+ return candidates
215
+
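As a quick illustration (not part of the file above), the micro-batch candidates are simply the divisors of the optimal batch size; with a hypothetical optimal batch size of 8:

opt_batch_size = 8  # hypothetical value
candidates = [n for n in range(1, opt_batch_size + 1) if opt_batch_size % n == 0]
assert candidates == [1, 2, 4, 8]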
216
+ def get_graph(self):
217
+ graph = PipelineGraph.from_trt(self.network)
218
+ graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
219
+ graph._io_buffer_mapping
220
+ for input in graph.inputs:
221
+ input_name = input.name
222
+ for pattern, repl in self.same_buffer_io.items():
223
+ if re.match(pattern, input_name):
224
+ output_name = re.sub(pattern, repl, input_name)
225
+ output = graph.get_output(output_name)
226
+ if output is not None:
227
+ graph._io_buffer_mapping[output_name] = input_name
228
+ return graph
229
+
230
+ def get_opt_batch_size(self):
231
+ input_tensors = self.llm_network._inputs
232
+ num_profiles = len(list(input_tensors.values())[0].profiles)
233
+ opt_batch_sizes = []
234
+ for i in range(num_profiles):
235
+ for input_tensor in input_tensors.values():
236
+ shape_profile = input_tensor.profiles[i]
237
+ opt_shape = shape_profile.opt
238
+ for j in range(len(input_tensor.shape)):
239
+ name = input_tensor.trt_tensor.get_dimension_name(j)
240
+ if name == 'batch_size':
241
+ opt_batch_sizes.append(opt_shape[j])
242
+ return min(opt_batch_sizes)
243
+
244
+ def get_module_hash(self, layer_range):
245
+ module_hash = ()
246
+ for i in layer_range:
247
+ assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
248
+ layer_name = self.network.get_layer(i).name
249
+ layer = self.graph.get_layer(layer_name)
250
+ module_hash += (layer.attrs["hash"], )
251
+ return module_hash
252
+
253
+ def get_network_hash(self) -> str:
254
+ return str(self.get_module_hash(range(self.network.num_layers)))
255
+
256
+ def collect_building_blocks(self, module_tree):
257
+ building_blocks = {}
258
+ queue = []
259
+ for tree in module_tree["children"].values():
260
+ queue.append(tree)
261
+ while len(queue) > 0:
262
+ while len(queue) > 0:
263
+ tree = queue.pop(0)
264
+ module_name = tree["name"]
265
+ if module_name is None:
266
+ for child in tree["children"].values():
267
+ queue.append(child)
268
+ continue
269
+ layer_range = self.module_to_layer_range_map[module_name]
270
+ module_hash = self.get_module_hash(layer_range)
271
+ if module_hash in building_blocks:
272
+ building_blocks[module_hash].append(tree)
273
+ else:
274
+ building_blocks[module_hash] = [tree]
275
+ for module_hash in [*building_blocks.keys()]:
276
+ if len(building_blocks[module_hash]) == 1:
277
+ tree = building_blocks[module_hash][0]
278
+ for child in tree["children"].values():
279
+ queue.append(child)
280
+ del building_blocks[module_hash]
281
+ blocks_by_module_hash = {
282
+ module_hash: [
283
+ BuildingBlock(self.graph,
284
+ self.module_to_layer_range_map[tree["name"]])
285
+ for tree in trees
286
+ ]
287
+ for module_hash, trees in building_blocks.items()
288
+ }
289
+ building_blocks = []
290
+ for block_list in blocks_by_module_hash.values():
291
+ for block in block_list:
292
+ building_blocks.append(block)
293
+ building_blocks = sorted(building_blocks,
294
+ key=lambda x: x.layer_range.start)
295
+ if len(building_blocks) >= 2:
296
+ for block, next_block in zip(building_blocks[:-1],
297
+ building_blocks[1:]):
298
+ block.layer_range = range(block.layer_range.start,
299
+ next_block.layer_range.start)
300
+ return building_blocks
301
+
302
+ def get_all_blocks(self):
303
+ building_blocks = []
304
+ for block_list in self.blocks_by_edge_hash.values():
305
+ for block in block_list:
306
+ building_blocks.append(block)
307
+ building_blocks = sorted(building_blocks,
308
+ key=lambda x: x.layer_range.start)
309
+ all_blocks = []
310
+ current_layer_index = 0
311
+ block_id = 0
312
+ for block in building_blocks:
313
+ assert current_layer_index <= block.layer_range.start
314
+ if current_layer_index < block.layer_range.start:
315
+ new_block = BuildingBlock(
316
+ self.graph,
317
+ range(current_layer_index, block.layer_range.start))
318
+ new_block.block_id = block_id
319
+ block_id += 1
320
+ all_blocks.append(new_block)
321
+ block.block_id = block_id
322
+ block_id += 1
323
+ all_blocks.append(block)
324
+ current_layer_index = block.layer_range.stop
325
+ if current_layer_index < self.graph.num_layers:
326
+ new_block = BuildingBlock(
327
+ self.graph, range(current_layer_index, self.graph.num_layers))
328
+ new_block.block_id = block_id
329
+ all_blocks.append(new_block)
330
+ sorted_layer_ids = get_sorted_layer_ids(self.network)
331
+ for block in all_blocks:
332
+ block.collect_relative_inter_edges(self.layer_to_block)
333
+ for layer_id in sorted_layer_ids:
334
+ if layer_id in block.layer_range:
335
+ block.sorted_layer_ids.append(layer_id)
336
+ return all_blocks
337
+
338
+ def get_backbone_blocks(self):
339
+ sorted_blocks = sorted(
340
+ self.blocks_by_edge_hash.values(),
341
+ key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
342
+ )
343
+ if len(sorted_blocks) == 0:
344
+ return []
345
+ else:
346
+ return sorted_blocks[-1]
347
+
348
+ def get_blocks_by_module_hash(self, blocks):
349
+ blocks_by_module_hash = {}
350
+ for block in blocks:
351
+ module_hash = self.get_module_hash(block.layer_range)
352
+ if module_hash not in blocks_by_module_hash:
353
+ blocks_by_module_hash[module_hash] = []
354
+ blocks_by_module_hash[module_hash].append(block)
355
+ for module_hash in [*blocks_by_module_hash.keys()]:
356
+ if len(blocks_by_module_hash[module_hash]) == 1:
357
+ del blocks_by_module_hash[module_hash]
358
+ return blocks_by_module_hash
359
+
360
+ def get_module_tree(self):
361
+ module_tree = {"children": {}, "name": None}
362
+ for module_name in self.module_to_layer_range_map.keys():
363
+ full_name = module_name.split('.')
364
+ current_tree = module_tree["children"]
365
+ for depth, name in enumerate(full_name):
366
+ if name not in current_tree:
367
+ current_tree[name] = {"children": {}, "name": None}
368
+ if depth == len(full_name) - 1:
369
+ current_tree[name]["name"] = module_name
370
+ else:
371
+ current_tree = current_tree[name]["children"]
372
+ return module_tree
373
+
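For reference (hypothetical module names, not taken from the file), get_module_tree turns dotted module paths into a nested dict where only leaf entries carry the full dotted name; for the modules "transformer" and "transformer.layers.0" the result is roughly:

example_tree = {
    "children": {
        "transformer": {
            "name": "transformer",
            "children": {
                "layers": {
                    "name": None,
                    "children": {
                        "0": {"name": "transformer.layers.0", "children": {}},
                    },
                },
            },
        },
    },
    "name": None,
}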
374
+ def get_blocks_by_edge_hash(self, blocks_by_module_hash):
375
+ blocks_by_edge_hash = {}
376
+ for block_list in blocks_by_module_hash.values():
377
+ for block in block_list:
378
+ block.collect_edges()
379
+ edge_hash = block.edge_hash
380
+ if edge_hash not in blocks_by_edge_hash:
381
+ blocks_by_edge_hash[edge_hash] = []
382
+ blocks_by_edge_hash[edge_hash].append(block)
383
+ for edge_hash in [*blocks_by_edge_hash.keys()]:
384
+ if len(blocks_by_edge_hash[edge_hash]) == 1:
385
+ del blocks_by_edge_hash[edge_hash]
386
+ else:
387
+ block_list = blocks_by_edge_hash[edge_hash]
388
+ blocks_by_edge_hash[edge_hash] = sorted(
389
+ block_list, key=lambda x: x.layer_range.start)
390
+ for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
391
+ for block in block_list:
392
+ block.type_id = type_id
393
+ return blocks_by_edge_hash
394
+
395
+ def get_layer_to_block(self):
396
+ layer_to_block = {}
397
+ for block_list in self.blocks_by_edge_hash.values():
398
+ for block in block_list:
399
+ for layer_index in block.layer_range:
400
+ layer_to_block[layer_index] = block
401
+ return layer_to_block
402
+
403
+ def clean_blocks(self):
404
+ for block in self.blocks:
405
+ block.p2p_type = None
406
+ block.is_superset = False
407
+ block.is_subset = False
408
+
409
+ def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
410
+ graph_config: GraphConfig):
411
+ if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
412
+ return
413
+ assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
414
+ block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)
415
+
416
+ for block in self.backbone_blocks:
417
+ block.p2p_type = None
418
+ for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
419
+ next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
420
+ last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
421
+ next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
422
+ )[0]
423
+ num_devices_per_host = phy_mesh.num_devices_per_host
424
+ next_block = self.backbone_blocks[(stage_index + 1) *
425
+ block_per_stage]
426
+ if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
427
+ next_block.p2p_type = P2PType.CROSS_HOST
428
+ graph_config.has_cross_host = True
429
+ else:
430
+ next_block.p2p_type = P2PType.CROSS_DEVICE
431
+ graph_config.has_cross_device = True
432
+
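A minimal sketch of the host-boundary test used above (assuming 8 devices per host; the device ids are illustrative):

num_devices_per_host = 8
assert (7 // num_devices_per_host) != (8 // num_devices_per_host)  # device 7 -> device 8 crosses hosts: CROSS_HOST
assert (3 // num_devices_per_host) == (4 // num_devices_per_host)  # device 3 -> device 4 stays on one host: CROSS_DEVICE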
433
+ def get_graph_mapping(self):
434
+ layer_mapping = {}
435
+ block_mapping = {}
436
+ p2p_types = {}
437
+ p2p_tensors = {}
438
+ for block_list in self.blocks_by_edge_hash.values():
439
+ superset_blocks = []
440
+ superset_block_index = {}
441
+ for block in block_list:
442
+ block_added = False
443
+ for index, superset_block in enumerate(list(superset_blocks)):
444
+ if block.p2p_type == superset_block.p2p_type:
445
+ if block.relative_inter_edges.issubset(
446
+ superset_block.relative_inter_edges):
447
+ block.is_subset = True
448
+ block.is_superset = False
449
+ superset_block_index[id(block)] = index
450
+ block_added = True
451
+ break
452
+ elif superset_block.relative_inter_edges.issubset(
453
+ block.relative_inter_edges):
454
+ superset_block.is_subset = True
455
+ superset_block.is_superset = False
456
+ block.is_subset = False
457
+ block.is_superset = True
458
+ superset_blocks[index] = block
459
+ superset_block_index[id(block)] = index
460
+ block_added = True
461
+ break
462
+ if not block_added:
463
+ block.is_subset = False
464
+ block.is_superset = True
465
+ superset_blocks.append(block)
466
+ superset_block_index[id(block)] = len(superset_blocks) - 1
467
+ for block in block_list:
468
+ assert not (block.is_subset and block.is_superset)
469
+ if block.is_subset:
470
+ superset_block = superset_blocks[superset_block_index[id(
471
+ block)]]
472
+ block_mapping[block.block_id] = superset_block.block_id
473
+ owned_inputs = map(
474
+ lambda x: x[0],
475
+ sorted(block.owned_inputs.items(), key=lambda x: x[1]))
476
+ superset_owned_inputs = map(
477
+ lambda x: x[0],
478
+ sorted(superset_block.owned_inputs.items(),
479
+ key=lambda x: x[1]))
480
+ for from_input_id, to_input_id in zip(
481
+ owned_inputs, superset_owned_inputs):
482
+ from_input_name = self.network.get_input(
483
+ from_input_id).name
484
+ to_input_name = self.network.get_input(to_input_id).name
485
+ layer_mapping[from_input_name] = to_input_name
486
+ for from_layer_id, to_layer_id in zip(
487
+ block.layer_range, superset_block.layer_range):
488
+ from_layer = self.network.get_layer(from_layer_id)
489
+ to_layer = self.network.get_layer(to_layer_id)
490
+ layer_mapping[from_layer.name] = to_layer.name
491
+ for i in range(from_layer.num_outputs):
492
+ from_output = from_layer.get_output(i)
493
+ if from_output.is_network_output:
494
+ to_output = to_layer.get_output(i)
495
+ layer_mapping[from_output.name] = to_output.name
496
+ if block.p2p_type is not None:
497
+ p2p_types[block.block_id] = block.p2p_type
498
+ p2p_tensors[block.block_id] = [
499
+ *set(block.get_input_names())
500
+ ]
501
+ for from_name, to_name in zip(
502
+ block.get_input_names(),
503
+ superset_block.get_input_names()):
504
+ layer_mapping[
505
+ f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
506
+ stage_id = 0
507
+ block_to_stage = {}
508
+ for block in self.blocks:
509
+ if block.p2p_type is not None:
510
+ stage_id += 1
511
+ block_to_stage[block.block_id] = stage_id
512
+ return GraphMapping(
513
+ layer_mapping,
514
+ block_mapping,
515
+ p2p_types,
516
+ p2p_tensors,
517
+ block_to_stage,
518
+ )
519
+
520
+ def create_simplified_graph(self, graph_config: GraphConfig):
521
+ new_graph = PipelineGraph.create_graph()
522
+ new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
523
+ layer_mapping = graph_config.graph_mapping.layer_mapping
524
+
525
+ for i in range(self.network.num_inputs):
526
+ trt_input = self.network.get_input(i)
527
+ if trt_input.name not in layer_mapping:
528
+ new_graph.add_input(trt_input)
529
+
530
+ last_blocks = {}
531
+ same_spec_mapping = {}
532
+ same_spec_layer_mapping = {}
533
+ shape_mapping = {}
534
+ building_block_id = 0
535
+ same_spec_ids = {}
536
+ same_spec_count = 0
537
+ for block in self.blocks:
538
+ if not block.is_subset:
539
+ stage_type = None
540
+ if not block.is_superset:
541
+ if block.block_id == 0:
542
+ stage_type = StageType.START
543
+ elif block.block_id == len(self.blocks) - 1:
544
+ stage_type = StageType.END
545
+ input_mapping = block.get_input_mapping(last_blocks)
546
+ for from_name, to_name in [*input_mapping.items()]:
547
+ if to_name in same_spec_mapping:
548
+ input_mapping[from_name] = same_spec_mapping[to_name]
549
+ if to_name in layer_mapping:
550
+ input_mapping[from_name] = layer_mapping[to_name]
551
+ if block.is_superset and block.p2p_type is not None:
552
+ for from_name, to_name in [*input_mapping.items()]:
553
+ output_tensor = new_graph.get_tensor(to_name)
554
+ p2p_layer = new_graph.as_trt().add_identity(
555
+ output_tensor.as_trt())
556
+ p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
557
+ p2p_layer.metadata = p2p_layer.name
558
+ p2p_tensor = p2p_layer.get_output(0)
559
+ p2p_tensor.name = f"{p2p_layer.name}_output"
560
+ wrapped_layer = new_graph.register_layer(p2p_layer)
561
+ wrapped_layer.attrs[
562
+ "building_block_id"] = building_block_id
563
+ wrapped_layer.attrs["p2p_type"] = block.p2p_type
564
+ input_mapping[from_name] = p2p_tensor.name
565
+ shape_mapping[p2p_tensor.name] = from_name
566
+ building_block_id += 1
567
+ for i in block.sorted_layer_ids:
568
+ layer = self.network.get_layer(i)
569
+ wrapped_layer = new_graph.add_layer(
570
+ layer,
571
+ input_mapping=input_mapping,
572
+ )
573
+ wrapped_layer.attrs["building_block_id"] = building_block_id
574
+ wrapped_layer.attrs["stage_type"] = stage_type
575
+ if block.is_superset:
576
+ last_blocks[block.type_id] = block
577
+
578
+ if block.type_id in same_spec_ids:
579
+ same_spec_id = same_spec_ids[block.type_id]
580
+ update_same_spec_count = False
581
+ else:
582
+ same_spec_id = same_spec_count
583
+ same_spec_ids[block.type_id] = same_spec_id
584
+ update_same_spec_count = True
585
+ count = same_spec_id
586
+ for i, (layer_offset,
587
+ output_index) in enumerate(block.outputs):
588
+ layer = self.network.get_layer(block.layer_range.start +
589
+ layer_offset)
590
+ tensor_name = layer.get_output(output_index).name
591
+ output_tensor = new_graph.get_tensor(tensor_name)
592
+ same_spec_layer = new_graph.as_trt().add_identity(
593
+ output_tensor.as_trt())
594
+ same_spec_layer.name = f"{tensor_name}_same_spec"
595
+ same_spec_layer.metadata = same_spec_layer.name
596
+ same_spec_tensor = same_spec_layer.get_output(0)
597
+ same_spec_tensor.name = f"{same_spec_layer.name}_output"
598
+ wrapped_layer = new_graph.register_layer(
599
+ same_spec_layer)
600
+ wrapped_layer.attrs[
601
+ "building_block_id"] = building_block_id
602
+ wrapped_layer.attrs["same_spec_id"] = count
603
+ count += 1
604
+ same_spec_mapping[tensor_name] = same_spec_tensor.name
605
+ same_spec_layer_mapping[
606
+ same_spec_layer.name] = layer.name
607
+ shape_mapping[same_spec_tensor.name] = tensor_name
608
+ for i, graph_input_index in enumerate(
609
+ block.owned_inputs.keys()):
610
+ input_name = self.network.get_input(
611
+ graph_input_index).name
612
+ input_tensor = new_graph.get_input(input_name)
613
+ input_tensor.attrs["same_spec_id"] = count
614
+ count += 1
615
+ if update_same_spec_count:
616
+ same_spec_count = count
617
+ building_block_id += 1
618
+ graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping
619
+
620
+ if len(self.backbone_blocks) >= 2:
621
+ start_block = self.backbone_blocks[0]
622
+ if start_block.is_subset:
623
+ start_block = self.blocks[graph_config.graph_mapping.
624
+ block_mapping[start_block.block_id]]
625
+ for i in start_block.layer_range:
626
+ layer_name = self.network.get_layer(i).name
627
+ layer = new_graph.get_layer(layer_name)
628
+ layer.attrs["in_start_block"] = True
629
+ end_block = self.backbone_blocks[-1]
630
+ if end_block.is_subset:
631
+ end_block = self.blocks[graph_config.graph_mapping.
632
+ block_mapping[end_block.block_id]]
633
+ for i in end_block.layer_range:
634
+ layer_name = self.network.get_layer(i).name
635
+ layer = new_graph.get_layer(layer_name)
636
+ layer.attrs["in_end_block"] = True
637
+ slowest_p2p_type = None
638
+ if graph_config.has_cross_host:
639
+ slowest_p2p_type = P2PType.CROSS_HOST
640
+ elif graph_config.has_cross_device:
641
+ slowest_p2p_type = P2PType.CROSS_DEVICE
642
+ if slowest_p2p_type is not None:
643
+ for block in self.blocks:
644
+ if block.is_superset and block.p2p_type == slowest_p2p_type:
645
+ for i in block.layer_range:
646
+ layer_name = self.network.get_layer(i).name
647
+ layer = new_graph.get_layer(layer_name)
648
+ layer.attrs["in_slowest_block"] = True
649
+
650
+ for i in range(self.network.num_outputs):
651
+ trt_output = self.network.get_output(i)
652
+ output = self.graph.get_output(trt_output.name)
653
+ if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
654
+ output.producer.index].is_subset:
655
+ continue
656
+ if trt_output.is_shape_tensor:
657
+ new_output = new_graph.add_output_shape(trt_output)
658
+ else:
659
+ new_output = new_graph.add_output(trt_output)
660
+ sharded_io = False
661
+ for pattern in self.sharded_io_allowlist:
662
+ if re.match(pattern, new_output.name):
663
+ sharded_io = True
664
+ break
665
+ if not sharded_io:
666
+ new_output.producer.attrs["is_replicated"] = True
667
+
668
+ for input in new_graph.inputs:
669
+ input_name = input.name
670
+ sharded_io = False
671
+ for pattern in self.sharded_io_allowlist:
672
+ if re.match(pattern, input_name):
673
+ sharded_io = True
674
+ break
675
+ if not sharded_io:
676
+ input.attrs["is_replicated"] = True
677
+ for pattern, repl in self.same_spec_io.items():
678
+ if re.match(pattern, input_name):
679
+ output_name = re.sub(pattern, repl, input_name)
680
+ output = new_graph.get_output(output_name)
681
+ if output is not None:
682
+ if "same_spec_id" in input.attrs:
683
+ same_spec_id = input.attrs["same_spec_id"]
684
+ else:
685
+ same_spec_id = same_spec_count
686
+ same_spec_count += 1
687
+ input.attrs["same_spec_id"] = same_spec_id
688
+ output.attrs["same_spec_id"] = same_spec_id
689
+ if math.prod(self.graph.get_input(
690
+ input_name).shape) < math.prod(
691
+ self.graph.get_output(output_name).shape):
692
+ input.attrs["no_memory_footprint"] = True
693
+ else:
694
+ output.attrs["no_memory_footprint"] = True
695
+
696
+ return new_graph, shape_mapping
697
+
698
+ def enrich_shape_info(self, shape_mapping):
699
+ shapes = self.shape_info.shapes.copy()
700
+ max_shapes = self.shape_info.max_shapes.copy()
701
+ values = self.shape_info.values.copy()
702
+ shape_layers = self.shape_info.shape_layers
703
+ for from_name, to_name in shape_mapping.items():
704
+ if to_name in shapes:
705
+ shapes[from_name] = shapes[to_name]
706
+ if to_name in max_shapes:
707
+ max_shapes[from_name] = max_shapes[to_name]
708
+ if to_name in values:
709
+ values[from_name] = values[to_name]
710
+ shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
711
+ return shape_info
712
+
713
+ def simplify_graph(
714
+ self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
715
+ num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
716
+ num_blocks = len(self.backbone_blocks)
717
+ if num_blocks % num_stages != 0:
718
+ return None, None
719
+ graph_config = GraphConfig()
720
+ graph_config.num_micro_batches = self.num_micro_batches
721
+ graph_config.num_blocks = num_blocks
722
+ graph_config.num_stages = num_stages
723
+ graph_config.phy_mesh = phy_mesh
724
+ stage_phy_meshes = phy_mesh.split_pipeline_meshes(
725
+ num_stages, num_devices_per_stage)
726
+ graph_config.stage_phy_meshes = stage_phy_meshes
727
+ with silent_trt_logger():
728
+ self.clean_blocks()
729
+ self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
730
+ graph_config.graph_mapping = self.get_graph_mapping()
731
+ new_graph, shape_mapping = self.create_simplified_graph(
732
+ graph_config)
733
+ shape_info = self.enrich_shape_info(shape_mapping)
734
+ new_graph.assign_shapes(shape_info)
735
+ return new_graph, graph_config
736
+
737
+ def get_graph_mapping_for_shape(self):
738
+ layer_mapping = {}
739
+ tensor_mapping = {}
740
+ for block_list in self.blocks_by_edge_hash.values():
741
+ head_block = block_list[0]
742
+ for block in block_list[1:]:
743
+ for from_layer_id, to_layer_id in zip(block.layer_range,
744
+ head_block.layer_range):
745
+ from_layer = self.network.get_layer(from_layer_id)
746
+ to_layer = self.network.get_layer(to_layer_id)
747
+ layer_mapping[from_layer.name] = to_layer.name
748
+ for i in range(from_layer.num_outputs):
749
+ tensor_mapping[from_layer.get_output(
750
+ i).name] = to_layer.get_output(i).name
751
+ return layer_mapping, tensor_mapping
752
+
753
+ def create_simplified_graph_for_shape(self):
754
+ new_graph = PipelineGraph.create_graph()
755
+
756
+ for i in range(self.network.num_inputs):
757
+ trt_input = self.network.get_input(i)
758
+ new_graph.add_input(trt_input)
759
+
760
+ head_blocks = {}
761
+ removed_blocks = set()
762
+ removed_layers = set()
763
+ for block_list in self.blocks_by_edge_hash.values():
764
+ head_block = block_list[0]
765
+ head_blocks[head_block.type_id] = head_block
766
+ for block in block_list[1:]:
767
+ removed_blocks.add(id(block))
768
+ for layer_index in block.layer_range:
769
+ removed_layers.add(layer_index)
770
+
771
+ for block in self.blocks:
772
+ if not id(block) in removed_blocks:
773
+ input_mapping = block.get_input_mapping(head_blocks)
774
+ for i in block.sorted_layer_ids:
775
+ layer = self.network.get_layer(i)
776
+ new_graph.add_layer(
777
+ layer,
778
+ input_mapping=input_mapping,
779
+ )
780
+
781
+ for i in range(self.network.num_outputs):
782
+ trt_output = self.network.get_output(i)
783
+ output = self.graph.get_output(trt_output.name)
784
+ if output.producer is not None and output.producer.index in removed_layers:
785
+ continue
786
+ if trt_output.is_shape_tensor:
787
+ new_graph.add_output_shape(trt_output)
788
+ else:
789
+ new_graph.add_output(trt_output)
790
+
791
+ return new_graph
792
+
793
+ def get_full_shape_info(self, num_micro_batches):
794
+ layer_mapping, tensor_mapping = self.graph_mapping_for_shape
795
+ optimization_profiles = self.llm_network._generate_optimization_profiles(
796
+ )
797
+ if len(optimization_profiles) > 0:
798
+ optimization_profile = optimization_profiles[-1]
799
+ else:
800
+ optimization_profile = None
801
+ shape_info = get_shape_info(self.graph_for_shape.as_trt(),
802
+ optimization_profile)
803
+ max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
804
+ optimization_profile,
805
+ shape_type=ShapeType.MAX)
806
+ shape_info.max_shapes = max_shape_info.shapes
807
+ for removed_tensor_name, tensor_name in tensor_mapping.items():
808
+ shape_info.shapes[removed_tensor_name] = shape_info.shapes[
809
+ tensor_name]
810
+ shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
811
+ tensor_name]
812
+ if tensor_name in shape_info.values:
813
+ shape_info.values[removed_tensor_name] = shape_info.values[
814
+ tensor_name]
815
+ for removed_layer_name, layer_name in layer_mapping.items():
816
+ if layer_name in shape_info.shape_layers:
817
+ shape_info.shape_layers.add(removed_layer_name)
818
+ return shape_info
819
+
820
+ def init_layer_hash(self):
821
+ with silent_trt_logger():
822
+ optimization_profiles = self.llm_network._generate_optimization_profiles(
823
+ )
824
+ if len(optimization_profiles) > 0:
825
+ optimization_profile = optimization_profiles[-1]
826
+ else:
827
+ optimization_profile = None
828
+ shape_info = get_shape_info(self.network, optimization_profile)
829
+ dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
830
+ for layer in self.graph.layers:
831
+ layer_hash = get_cache_key(
832
+ layer.as_trt(),
833
+ shape_info.shapes,
834
+ shape_info.values,
835
+ dtypes,
836
+ )
837
+ layer.attrs["hash"] = layer_hash
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/solver.py ADDED
@@ -0,0 +1,641 @@
1
+ """This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes.
2
+ """
3
+ import multiprocessing
4
+ import time
5
+ import warnings
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ import pulp
10
+ from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum
11
+
12
+ from ..logger import logger
13
+
14
+
15
+ class Solution:
16
+
17
+ def __init__(self, leaf_strategies, s_val, e_val, edge_pairs,
18
+ node_index_dict, total_cost):
19
+ self.leaf_strategies = leaf_strategies
20
+ self.nodes = [
21
+ strategies_vector.node for strategies_vector in self.leaf_strategies
22
+ ]
23
+ self.s_val = s_val
24
+ self.e_val = e_val
25
+ self.total_cost = total_cost
26
+ self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2)))
27
+ self.node_index_dict = node_index_dict
28
+ self.index_node_dict = {}
29
+ for node, index in self.node_index_dict.items():
30
+ self.index_node_dict[index] = node
31
+ self.node_best_strategy = {}
32
+ self._annotate_strategy()
33
+
34
+ def _annotate_strategy(self):
35
+ self.node_best_strategy = {}
36
+ for index, node in enumerate(self.nodes):
37
+ best_strategy_id = self.s_val[index]
38
+ best_strategy = self.leaf_strategies[index][best_strategy_id]
39
+ self.node_best_strategy[node.node_name] = best_strategy
40
+
41
+ for edge_idx, edge_pair in enumerate(self.edge_pairs):
42
+ src_node = self.index_node_dict[edge_pair[0]]
43
+ dst_node = self.index_node_dict[edge_pair[1]]
44
+ src_node_index = self.node_index_dict[src_node]
45
+ for dst_pre_node in dst_node.predecessor_nodes:
46
+ if dst_pre_node is None:
47
+ continue
48
+ if src_node.node_name == dst_pre_node.node_name:
49
+ self.node_best_strategy[
50
+ dst_node.node_name].best_resharding_cost[
51
+ src_node.node_name] = [
52
+ self.node_best_strategy[dst_node.node_name].
53
+ resharding_costs[src_node.node_name][
54
+ self.s_val[src_node_index]]
55
+ ]
56
+
57
+ def print_solution(self):
58
+ for index, node in enumerate(self.nodes):
59
+ best_strategy = self.node_best_strategy[node.node_name]
60
+ print(f'\n[{index}]: node_name = {node.node_name}')
61
+ best_strategy.print_strategy(best_resharding_cost_only=True)
62
+ print(f'solution total cost = {self.total_cost}')
63
+
64
+
65
+ class CostGraph:
66
+ '''
67
+ A graph data structure to simplify the edge cost graph. It has two main functions:
+ 1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
+ CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
+ 2. To reduce the search space, we merge computationally trivial operators, such as
+ element-wise operators, transpose, and reduction, into their following nodes. The merging information is
+ given by the StrategiesVector depending on the type of the target node and its following nodes.
+
+ Argument:
+ leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node in the graph.
+ simplify(bool, optional): The generated cost graph will be simplified if it is true. (defaults to True)
77
+ '''
78
+
79
+ def __init__(self, leaf_strategies):
80
+ self.leaf_strategies = leaf_strategies
81
+ self.nodes = [
82
+ strategies_vector.node for strategies_vector in leaf_strategies
83
+ ]
84
+ # stores the strategies vector of each node
85
+ self.node_strategies_vector = {}
86
+ for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
87
+ self.node_strategies_vector[node] = strategies_vector
88
+ # extra_node_costs will store the extra costs introduced by merging nodes
89
+ self.extra_node_costs = {}
90
+ self.following_dict = {}
91
+ self._build_cost_graph()
92
+
93
+ def _remove_invalid_node(self, node, attr_name):
94
+ remove_list = []
95
+ target_node_list = getattr(node, attr_name, [])
96
+ for target_node in target_node_list:
97
+ if target_node not in self.nodes:
98
+ remove_list.append(target_node)
99
+ for element in remove_list:
100
+ target_node_list.remove(element)
101
+
102
+ def _build_cost_graph(self):
103
+ '''
104
+ This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
105
+ set to node.
106
+ '''
107
+ self.edge_costs = {}
108
+ for dst_node, strategies_vector in zip(self.nodes,
109
+ self.leaf_strategies):
110
+ # build edge_cost
111
+ for src_node in dst_node.predecessor_nodes:
112
+ if src_node is None:
113
+ continue
114
+ if src_node not in self.nodes:
115
+ continue
116
+ node_pair = (src_node, dst_node)
117
+ edge_cost = {}
118
+ for i in range(len(strategies_vector)):
119
+ for j in range(len(self.node_strategies_vector[src_node])):
120
+ resharding_cost = strategies_vector[i].resharding_costs[
121
+ src_node.node_name][j][-1]
122
+ edge_cost[(j, i)] = resharding_cost
123
+ self.edge_costs[node_pair] = edge_cost
124
+
125
+ def get_edge_cost(self, src_node, dst_node):
126
+ return self.edge_costs[(src_node, dst_node)]
127
+
128
+
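As a rough illustration of the edge-cost linearization described in the CostGraph docstring (made-up costs, not part of the diff): for a source node with two strategies and a destination node with three, the quadratic cost matrix is flattened into a dict keyed by (src_strategy, dst_strategy):

costs = [[0.0, 1.5, 2.0],   # hypothetical resharding costs for src strategy 0
         [1.5, 0.0, 3.0]]   # hypothetical resharding costs for src strategy 1
edge_cost = {(j, i): costs[j][i] for j in range(2) for i in range(3)}
assert edge_cost[(1, 2)] == 3.0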
129
+ class Solver:
130
+ INFINITY_COST = 1e13
131
+
132
+ def __init__(self,
133
+ cost_graph: CostGraph,
134
+ memory_budget: float = -1.0,
135
+ solution_numbers: int = 1,
136
+ memory_increasing_coefficient: float = 1.3,
137
+ verbose=False):
138
+ '''
139
+ The Solver class integrates information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.
+ Argument:
+ graph: The computing graph to be optimized.
+ strategies_constructor: It provides all the possible strategies for each node in the computing graph.
+ cost_graph: A graph data structure to simplify the edge cost graph.
+ graph_analyser: graph_analyser analyses the graph to obtain variable liveness information, which is used to generate memory constraints.
+ memory_budget: Memory constraint for the solution.
+ solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
+ memory_increasing_coefficient: If solution_numbers is larger than one, we use this coefficient to generate new memory budgets.
148
+ '''
149
+ self.cost_graph = cost_graph
150
+ self.leaf_strategies = cost_graph.leaf_strategies
151
+ self.nodes = cost_graph.nodes
152
+ self.memory_budget = memory_budget
153
+ self.solution_numbers = solution_numbers
154
+ if self.solution_numbers > 1:
155
+ self.memory_increasing_coefficient = memory_increasing_coefficient
156
+ else:
157
+ self.memory_increasing_coefficient = 1
158
+ # temporarily we use all nodes as the liveness list; we count the backward memory cost together with
+ # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
160
+ # self.liveness_list = self.graph_analyser.liveness_analysis()
161
+ self.liveness_list = self.nodes
162
+ self.node_index_dict = self._generate_node_index_dict()
163
+ # The last solution vector of auto sharding.
164
+ self.last_s_val = None
165
+ # The last objective value of the best ILP solution.
166
+ self.last_objective = None
167
+ self.verbose = verbose
168
+
169
+ def _generate_node_index_dict(self):
170
+ node_index_dict = {}
171
+ for index, node in enumerate(self.nodes):
172
+ node_index_dict[node] = index
173
+ return node_index_dict
174
+
175
+ def _prepare_data_for_solver(self):
176
+ '''
177
+ Extract information from components for solver.
178
+ '''
179
+ node_nums = len(self.leaf_strategies)
180
+ memory_budget = self.memory_budget
181
+
182
+ # prepare strategies_len
183
+ strategies_len = []
184
+ for node in self.nodes:
185
+ strategies_len.append(
186
+ len(self.cost_graph.node_strategies_vector[node]))
187
+ strategies_len = np.array(strategies_len)
188
+
189
+ # prepare edge_pairs and resharding costs
190
+ edge_pairs = []
191
+ resharding_costs = []
192
+ edge_cost_level = []
193
+ edge_resharding_weights = []
194
+ for pairs, edge_cost in self.cost_graph.edge_costs.items():
195
+ src_node = pairs[0]
196
+ dst_node = pairs[1]
197
+ src_node_index = self.node_index_dict[src_node]
198
+ dst_node_index = self.node_index_dict[dst_node]
199
+ edge_pairs.append(src_node_index)
200
+ edge_pairs.append(dst_node_index)
201
+ edge_cost_level.append(
202
+ (dst_node.building_block_id, dst_node.cost_level))
203
+ for i in range(strategies_len[src_node_index]):
204
+ for j in range(strategies_len[dst_node_index]):
205
+ resharding_costs.append(edge_cost[(i, j)])
206
+ edge_resharding_weights.append(dst_node.resharding_weight +
207
+ dst_node.pipeline_weight)
208
+ edge_pairs = np.array(edge_pairs)
209
+ resharding_costs = np.array(resharding_costs)
210
+ edge_resharding_weights = np.array(edge_resharding_weights)
211
+ # prepare compute_costs, communication_costs and memory_costs
212
+ compute_costs = []
213
+ communication_costs = []
214
+ memory_costs = []
215
+ peak_act_memory_costs, constant_memory_costs = [], []
216
+ node_sharding_weights = []
217
+ for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
218
+ for index, strategy in enumerate(strategies_vector):
219
+ compute_cost = strategy.sharding_cost
220
+ origin_communication_cost = strategy.communication_cost
221
+ memory_cost = strategy.const_memory_footprint * node.sharding_weight
222
+ peak_act_memory = strategy.peak_memory_footprint
223
+ # extract the memory cost in float from MemoryCost item and sum them up
224
+ compute_costs.append(compute_cost)
225
+ # a node in extra_node_costs has some extra communication cost introduced
+ # by node merging, so those extra communication costs would need to be
+ # added in as well
228
+
229
+ communication_costs.append(origin_communication_cost)
230
+ peak_act_memory_costs.append(peak_act_memory)
231
+ constant_memory_costs.append(memory_cost)
232
+ node_sharding_weights.append(node.sharding_weight +
233
+ node.pipeline_weight)
234
+
235
+ compute_costs = np.array(compute_costs)
236
+ communication_costs = np.array(communication_costs)
237
+ memory_costs = np.array([constant_memory_costs, peak_act_memory_costs])
238
+ node_sharding_weights = np.array(node_sharding_weights)
239
+ same_spec_nodes_dict = defaultdict(list)
240
+ node_cost_level = []
241
+ for idx, node in enumerate(self.nodes):
242
+ if node.same_spec_id >= 0:
243
+ same_spec_nodes_dict[node.same_spec_id].append(idx)
244
+ node_cost_level.append((node.building_block_id, node.cost_level))
245
+ # omit initial value for nodes
246
+ s_init_np = None
247
+ following_nodes = [-1 for i in range(node_nums)]
248
+ liveness_set = self.nodes
249
+ alias_set = []
250
+ alias_convert_costs = None
251
+ return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, edge_cost_level, alias_convert_costs, s_init_np, self.verbose
252
+
253
+ def _call_solver_serialized_args(self,
254
+ node_nums,
255
+ memory_budget,
256
+ strategies_len,
257
+ following_nodes,
258
+ edge_pairs,
259
+ alias_set,
260
+ liveness_set,
261
+ compute_costs,
262
+ communication_costs,
263
+ memory_costs,
264
+ resharding_costs,
265
+ node_sharding_weights,
266
+ edge_resharding_weights,
267
+ same_spec_nodes_dict,
268
+ node_cost_level,
269
+ edge_cost_level,
270
+ alias_convert_costs,
271
+ s_init_np=None,
272
+ verbose=True):
273
+ """
274
+ Call the solver with serialized arguments.
275
+ """
276
+
277
+ time.time()
278
+
279
+ for x in [
280
+ strategies_len, edge_pairs, compute_costs, communication_costs,
281
+ memory_costs, resharding_costs, node_sharding_weights,
282
+ edge_resharding_weights
283
+ ]:
284
+ assert isinstance(x, np.ndarray)
285
+ assert len(strategies_len) == node_nums, "strategies_len"
286
+
287
+ def get_non_zero_index(binary_vector):
288
+ """
289
+ Get the index of non-zero item in a vector.
290
+ """
291
+ ct = 0
292
+ ret = None
293
+ for i, elem in enumerate(binary_vector):
294
+ if pulp.value(elem):
295
+ ret = i
296
+ ct += 1
297
+
298
+ assert ct == 1
299
+ return ret
300
+
301
+ # 0. Unpack flatten numpy arrays
302
+ s_follow = following_nodes
303
+ s_alias = alias_set
304
+
305
+ E = edge_pairs.reshape((-1, 2)) # noqa
306
+ r = []
307
+ pt = 0
308
+ edge_set = set()
309
+ for (i, j) in E:
310
+ prod_length = strategies_len[i] * strategies_len[j]
311
+
312
+ if (i, j) in edge_set:
313
+ raise ValueError(f"Duplicated edges: {(i, j)}")
314
+
315
+ edge_set.add((i, j))
316
+ r.append(resharding_costs[pt:pt + prod_length])
317
+ pt += prod_length
318
+ assert pt == len(resharding_costs)
319
+
320
+ ######################
321
+ # omit alias set now #
322
+ ######################
323
+
324
+ # A = alias_set.reshape((-1, 2)) # noqa
325
+ # for (i, j) in A:
326
+ # prod_length = strategies_len[i] * strategies_len[j]
327
+ # v.append(alias_convert_costs[pt:pt + prod_length])
328
+ # pt += prod_length
329
+ # assert pt == len(alias_convert_costs)
330
+
331
+ # L = [] # noqa
332
+ # pt = node_nums
333
+ # for i in range(node_nums):
334
+ # length = liveness_set[i]
335
+ # L.append(liveness_set[pt:pt + length])
336
+ # pt += length
337
+ # assert pt == len(liveness_set)
338
+ pt = 0
339
+
340
+ c = []
341
+ d = []
342
+ m = []
343
+ peak_m = []
344
+ pt = 0
345
+ for i in range(node_nums):
346
+ length = strategies_len[i]
347
+ c.append(compute_costs[pt:pt + length])
348
+ d.append(communication_costs[pt:pt + length])
349
+ m.append(memory_costs[0][pt:pt + length])
350
+ peak_m.append(memory_costs[1][pt:pt + length])
351
+ pt += length
352
+ assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
353
+ assert pt == len(
354
+ communication_costs), f"{pt} == {len(communication_costs)}"
355
+ assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}"
356
+
357
+ # 1. Create variables
358
+
359
+ #############################
360
+ # create variables for node #
361
+ #############################
362
+ s = []
363
+ num_nodes = 0
364
+ reverse_follow_backpatch = []
365
+ for i in range(node_nums):
366
+ if s_follow[i] < 0:
367
+ if strategies_len[i] == 1:
368
+ s.append([1])
369
+ else:
370
+ if i not in s_alias:
371
+ num_nodes += 1
372
+ s.append(
373
+ LpVariable.matrix(f"s[{i}]",
374
+ (range(strategies_len[i]), ),
375
+ cat="Binary"))
376
+ else:
377
+ s.append(s[s_alias[i]])
378
+ else:
379
+ if s_follow[i] < len(s):
380
+ s.append(s[s_follow[i]])
381
+ else:
382
+ s.append(None)
383
+ reverse_follow_backpatch.append(i)
384
+
385
+ for i in reverse_follow_backpatch:
386
+ s[i] = s[s_follow[i]]
387
+
388
+ #############################
389
+ # create variables for edge #
390
+ #############################
391
+ e = []
392
+ num_edges = 0
393
+ map_edge_to_idx = {}
394
+ for (idx, (i, j)) in enumerate(E):
395
+ if len(s[i]) == 1:
396
+ e.append(s[j])
397
+ elif len(s[j]) == 1:
398
+ e.append(s[i])
399
+ else:
400
+ if i in s_alias and j in s_alias and (
401
+ s_alias[i], s_alias[j]) in map_edge_to_idx:
402
+ e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
403
+ else:
404
+ num_edges += 1
405
+ e.append(
406
+ LpVariable.matrix(f"e[{i},{j}]",
407
+ (range(len(s[i]) * len(s[j])), ),
408
+ cat="Binary"))
409
+ assert len(e[idx]) == len(r[idx])
410
+ map_edge_to_idx[(i, j)] = idx
411
+ for element in s:
412
+ assert len(element) > 0
413
+ # 2. Set initial value
414
+ ######################################
415
+ # set a initial value for warm start #
416
+ ######################################
417
+ if s_init_np is not None:
418
+ s_init = s_init_np.reshape((-1, 3))
419
+ for (idx, value, fix) in s_init:
420
+ for i in range(len(s[idx])):
421
+ s[idx][i].setInitialValue(i == value)
422
+ if fix:
423
+ s[idx][i].fixValue()
424
+
425
+ # 3. Objective
426
+ prob = LpProblem("myProblem", LpMinimize)
427
+ ###################################################################
428
+ # computing the node cost(computing cost and communication cost) #
429
+ ###################################################################
430
+ obj = 0
431
+ block_cost_level_dict = {}
432
+ for i in range(node_nums):
433
+ assert len(s[i]) == len(c[i])
434
+ assert len(s[i]) == len(d[i])
435
+ obj += (lpDot(s[i], c[i]) +
436
+ lpDot(s[i], d[i])) * node_sharding_weights[i]
437
+ cost_level = node_cost_level[i]
438
+ if -1 != cost_level[1]:
439
+ if cost_level in block_cost_level_dict:
440
+ block_cost_level_dict[cost_level] += lpDot(
441
+ s[i], c[i]) + lpDot(s[i], d[i])
442
+ else:
443
+ block_cost_level_dict[cost_level] = lpDot(
444
+ s[i], c[i]) + lpDot(s[i], d[i])
445
+
446
+ #############################################
447
+ # computing the edge cost(resharding cost) #
448
+ #############################################
449
+
450
+ for i in range(len(E)):
451
+ assert len(e[i]) == len(r[i])
452
+ obj += lpDot(e[i], r[i]) * edge_resharding_weights[i]
453
+ cost_level = edge_cost_level[i]
454
+ if -1 != cost_level[1]:
455
+ if cost_level in block_cost_level_dict:
456
+ block_cost_level_dict[cost_level] += lpDot(e[i], r[i])
457
+ else:
458
+ block_cost_level_dict[cost_level] = lpDot(e[i], r[i])
459
+ prob += obj
460
+ if len(block_cost_level_dict) >= 2:
461
+ block_cost_levels = [key for key in block_cost_level_dict.keys()]
462
+ for i in range(len(block_cost_levels)):
463
+ for j in range(i + 1, len(block_cost_levels)):
464
+ if block_cost_levels[i][1] > block_cost_levels[j][1]:
465
+ prob += block_cost_level_dict[
466
+ block_cost_levels[i]] >= block_cost_level_dict[
467
+ block_cost_levels[j]] + 1e-6
468
+ elif block_cost_levels[i][1] < block_cost_levels[j][1]:
469
+ prob += block_cost_level_dict[
470
+ block_cost_levels[j]] >= block_cost_level_dict[
471
+ block_cost_levels[i]] + 1e-6
472
+ # 4. Constraints
473
+ # (a). specified by `cat="Binary"`
474
+
475
+ # (b)
476
+ #################################################
477
+ # make sure each node only choose one strategy #
478
+ #################################################
479
+ for i in range(node_nums):
480
+ if s_follow[i] < 0:
481
+ prob += lpSum(s[i]) == 1
482
+
483
+ # (c)
484
+ #################################################
485
+ # force to constrain some nodes have the same sharding specs #
486
+ #################################################
487
+ for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items():
488
+ num_same_spec_nodes = len(same_spec_nodes_id)
489
+ if num_same_spec_nodes >= 2:
490
+ src_node_s = s[same_spec_nodes_id[0]]
491
+ num_specs = len(src_node_s)
492
+ for i in range(1, num_same_spec_nodes):
493
+ dst_node_s = s[same_spec_nodes_id[i]]
494
+ assert len(
495
+ dst_node_s
496
+ ) == num_specs, f'unmatched num_specs when force node {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} the same specs'
497
+ for j in range(num_specs):
498
+ prob += (src_node_s[j] == dst_node_s[j])
499
+
500
+ # (c)
501
+ #################################################
502
+ # compute memory consumption with liveness set #
503
+ #################################################
504
+ if memory_budget > 0:
505
+ # calculate the constant memory
506
+ mem = 0
507
+ for node in liveness_set:
508
+ if node not in self.node_index_dict:
509
+ continue
510
+ node_index = self.node_index_dict[node]
511
+ mem += lpSum(s[node_index][j] * m[node_index][j]
512
+ for j in range(len(s[node_index])))
513
+ # calculate the peak activation memory
514
+ for node in liveness_set:
515
+ if node not in self.node_index_dict:
516
+ continue
517
+ node_index = self.node_index_dict[node]
518
+ cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j]
519
+ for j in range(len(s[node_index])))
520
+ total_mem = mem + cur_peak_mem
521
+ prob += total_mem <= memory_budget
522
+
523
+ # (d). specified by `cat="Binary"`
524
+
525
+ for (idx, (i, j)) in enumerate(E):
526
+ if strategies_len[i] == 1 or strategies_len[j] == 1:
527
+ continue
528
+
529
+ # (e)
530
+ prob += lpSum(e[idx]) == 1
531
+
532
+ # (f)
533
+ for row in range(len(s[i])):
534
+ C = len(s[j]) # noqa
535
+ prob += lpSum(e[idx][row * C + col]
536
+ for col in range(0, C)) <= s[i][row]
537
+
538
+ # (g)
539
+ for col in range(len(s[j])):
540
+ R = len(s[i]) # noqa
541
+ C = len(s[j]) # noqa
542
+ prob += lpSum(e[idx][row * C + col]
543
+ for row in range(0, R)) <= s[j][col]
544
+
545
+ if prob.objective.isNumericalConstant():
546
+ objective = float(pulp.value(prob.objective))
547
+ status = pulp.LpStatusOptimal
548
+ else:
549
+ msg = verbose
550
+ time_limit = 600
551
+ solver = pulp.PULP_CBC_CMD(
552
+ mip=True,
553
+ msg=msg,
554
+ timeLimit=time_limit,
555
+ threads=multiprocessing.cpu_count(),
556
+ )
557
+ prob.solve(solver)
558
+
559
+ status = prob.status
560
+ objective = pulp.value(prob.objective)
561
+ objective = float(
562
+ objective) if objective is not None else self.INFINITY_COST
563
+
564
+ if prob.status in [pulp.LpStatusInfeasible]:
565
+ objective = self.INFINITY_COST
566
+
567
+ # Get and check results
568
+ s_val = np.full((node_nums, ), -1, dtype=np.int32)
569
+ for i in range(node_nums):
570
+ s_val[i] = get_non_zero_index(s[i])
571
+
572
+ e_val = np.full((len(E), ), -1, dtype=np.int32)
573
+ for (idx, (i, j)) in enumerate(E):
574
+ e_val[idx] = get_non_zero_index(e[idx])
575
+ i_spec_index = e_val[idx] // len(s[j])
576
+ j_spec_index = e_val[idx] % len(s[j])
577
+ assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
578
+ assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
579
+ if verbose and r[idx][e_val[idx]] > 0:
580
+ print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
581
+
582
+ self.last_s_val = list(s_val)
583
+ # self._recover_merged_node_strategy()
584
+ self.last_objective = objective
585
+
586
+ if objective >= self.INFINITY_COST:
587
+ warnings.warn(
588
+ f"Cannot find an optimized solution given memory budget {self.memory_budget}, please consider:\n" + \
+ f"1. increasing the memory budget if possible\n" + \
+ f"2. enlarging the mesh shape if possible\n" + \
+ f"3. decreasing the maximum parameters (e.g., max_batch_size, max_seq_len) in the building config")
592
+ if memory_budget > 0:
593
+ # calculate the constant memory
594
+ mem = 0
595
+ for node in liveness_set:
596
+ if node not in self.node_index_dict:
597
+ continue
598
+ node_index = self.node_index_dict[node]
599
+ j = self.last_s_val[node_index]
600
+ mem += m[node_index][j]
601
+ max_peak_mem = 0
602
+ for node in liveness_set:
603
+ if node not in self.node_index_dict:
604
+ continue
605
+ node_index = self.node_index_dict[node]
606
+ j = self.last_s_val[node_index]
607
+ cur_peak_mem = peak_m[node_index][j]
608
+ max_peak_mem = max(max_peak_mem, cur_peak_mem)
609
+ logger.debug(
610
+ f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
611
+ )
612
+
613
+ solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
614
+ edge_pairs, self.node_index_dict,
615
+ self.last_objective)
616
+ return status, solution
617
+
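A standalone toy sketch of the ILP structure built above (two nodes with two strategies each and hypothetical costs): each node must pick exactly one strategy, and the edge variable must agree with the strategies chosen at both endpoints.

import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

node_costs = [[1.0, 4.0], [2.0, 1.0]]  # per-strategy node costs (hypothetical)
edge_costs = [0.0, 5.0, 5.0, 0.0]      # resharding costs, row-major over (s0, s1) (hypothetical)
s0 = LpVariable.matrix("s0", (range(2), ), cat="Binary")
s1 = LpVariable.matrix("s1", (range(2), ), cat="Binary")
e = LpVariable.matrix("e", (range(4), ), cat="Binary")
prob = LpProblem("toy", LpMinimize)
prob += lpDot(s0, node_costs[0]) + lpDot(s1, node_costs[1]) + lpDot(e, edge_costs)
prob += lpSum(s0) == 1  # constraint (b): node 0 picks exactly one strategy
prob += lpSum(s1) == 1  # constraint (b): node 1 picks exactly one strategy
prob += lpSum(e) == 1   # constraint (e): exactly one strategy combination on the edge
for row in range(2):    # constraint (f): edge choice consistent with s0
    prob += lpSum(e[row * 2 + col] for col in range(2)) <= s0[row]
for col in range(2):    # constraint (g): edge choice consistent with s1
    prob += lpSum(e[row * 2 + col] for row in range(2)) <= s1[col]
prob.solve(pulp.PULP_CBC_CMD(msg=False))
# the optimum picks strategy 0 for both nodes (total cost 3.0, zero resharding)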
618
+ def find_solution(self):
619
+ """
620
+ Call the solver with serialized arguments and handle Python errors. Additionally,
+ we can produce a series of solutions with different memory budgets.
622
+ """
623
+ if self.solution_numbers == 1:
624
+ args = self._prepare_data_for_solver()
625
+ ret = self._call_solver_serialized_args(*args)
626
+
627
+ return ret
628
+
629
+ origin_memory_budget = self.memory_budget
630
+ memory_budget_list = [
631
+ origin_memory_budget * self.memory_increasing_coefficient**i
632
+ for i in range(self.solution_numbers)
633
+ ]
634
+ ret_list = []
635
+ for memory_budget in memory_budget_list:
636
+ self.memory_budget = memory_budget
637
+ args = self._prepare_data_for_solver()
638
+ ret = self._call_solver_serialized_args(*args)
639
+ ret_list.append(ret)
640
+
641
+ return ret_list
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py ADDED
File without changes
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py ADDED
@@ -0,0 +1,41 @@
1
+ import copy
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Activation(Node):
8
+
9
+ def _collect_strategies(self, device_mesh):
10
+ dim_partition_list = []
11
+ dim_size = len(self.op_data['input0'].shape)
12
+ dim_partition_list.append({})
13
+ dim_partition_list.extend(
14
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
15
+ dim_partition_list.extend(
16
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
17
+ dim_partition_list.extend(
18
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
19
+ dim_partition_list.extend(
20
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
21
+ strategies_vector = StrategiesVector(self)
22
+ for dim_partition_dict in dim_partition_list:
23
+ in0_partition_dict = dim_partition_dict
24
+ out_partition_dict = copy.deepcopy(dim_partition_dict)
25
+ dim_partition_dict_mapping = {
26
+ "input0": in0_partition_dict,
27
+ "output0": out_partition_dict,
28
+ }
29
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
30
+ dim_partition_dict_mapping, device_mesh)
31
+ if 0 == len(sharding_spec_mapping):
32
+ continue
33
+ name = '{} = <activation op> {}'.format(
34
+ sharding_spec_mapping['output0'].sharding_sequence,
35
+ sharding_spec_mapping['input0'].sharding_sequence)
36
+ sharding_strategy = self._get_sharding_strategy(
37
+ name=name,
38
+ sharding_spec_mapping=sharding_spec_mapping,
39
+ communication_action_mapping={})
40
+ strategies_vector.append(sharding_strategy)
41
+ return strategies_vector
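A hypothetical helper mirroring the enumeration pattern used above (an assumption for illustration, not the project's actual implementation): 2-D sharding assigns two distinct tensor dims to two distinct mesh axes.

from itertools import permutations

def enumerate_2d_sharding_sketch(mesh_axes, dim_size):
    # e.g. enumerate_2d_sharding_sketch((0, 1), 2) -> [{0: [0], 1: [1]}, {1: [0], 0: [1]}]
    return [{i: [mesh_axes[0]], j: [mesh_axes[1]]}
            for i, j in permutations(range(dim_size), 2)]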
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py ADDED
@@ -0,0 +1,34 @@
1
+ import copy
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Assertion(Node):
8
+
9
+ def _collect_strategies(self, device_mesh):
10
+ predecessor = self.predecessor_nodes[0] # one input for softmax node
11
+ strategies_vector = StrategiesVector(self)
12
+ for idx, strategy in enumerate(predecessor.strategies_vector):
13
+ global_input_name = self.op_data[
14
+ 'input0'].name # current node's local name input0 -> global name xxx
15
+ prenode_local_name = predecessor.global_to_local_op_name[
16
+ global_input_name] # global name xxx -> pre node local output name
17
+ dim_partition_dict = copy.deepcopy(
18
+ strategy.sharding_specs[prenode_local_name].dim_partition_dict)
19
+ in0_partition_dict = dim_partition_dict
20
+ dim_partition_dict_mapping = {
21
+ "input0": in0_partition_dict,
22
+ }
23
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
24
+ dim_partition_dict_mapping, device_mesh)
25
+ if 0 == len(sharding_spec_mapping):
26
+ return strategies_vector
27
+ name = '<assertion> {}'.format(
28
+ sharding_spec_mapping['input0'].sharding_sequence)
29
+ sharding_strategy = self._get_sharding_strategy(
30
+ name=name,
31
+ sharding_spec_mapping=sharding_spec_mapping,
32
+ communication_action_mapping={})
33
+ strategies_vector.append(sharding_strategy)
34
+ return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py ADDED
@@ -0,0 +1,45 @@
1
+ import copy
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Cast(Node):
8
+
9
+ def _collect_strategies(self, device_mesh):
10
+ dim_partition_list = []
11
+ dim_size = len(self.op_data['input0'].shape)
12
+ dim_partition_list.append({})
13
+ dim_partition_list.extend(
14
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
15
+ dim_partition_list.extend(
16
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
17
+ dim_partition_list.extend(
18
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
19
+ dim_partition_list.extend(
20
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
21
+ strategies_vector = StrategiesVector(self)
22
+ for dim_partition_dict in dim_partition_list:
23
+ in0_partition_dict = dim_partition_dict
24
+ out_partition_dict = copy.deepcopy(dim_partition_dict)
25
+ dim_partition_dict_mapping = {
26
+ "input0": in0_partition_dict,
27
+ "output0": out_partition_dict,
28
+ }
29
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
30
+ dim_partition_dict_mapping, device_mesh)
31
+ if 0 == len(sharding_spec_mapping):
32
+ continue
33
+ name = '{} = <cast op> {}'.format(
34
+ sharding_spec_mapping['output0'].sharding_sequence,
35
+ sharding_spec_mapping['input0'].sharding_sequence)
36
+ sharding_strategy = self._get_sharding_strategy(
37
+ name=name,
38
+ sharding_spec_mapping=sharding_spec_mapping,
39
+ communication_action_mapping={})
40
+ strategies_vector.append(sharding_strategy)
41
+
42
+ return strategies_vector
43
+
44
+ def _update_memory_cost(self, strategies):
45
+ pass
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py ADDED
@@ -0,0 +1,58 @@
+ __all__ = [
+     'CommSpec',
+ ]
+
+
+ class CommSpec:
+
+     def __init__(self,
+                  comm_pattern,
+                  sharding_spec,
+                  gather_dim=None,
+                  shard_dim=None,
+                  logical_process_axis=None,
+                  mix_gather=False,
+                  forward_only=True):
+         self.comm_pattern = comm_pattern
+         self.sharding_spec = sharding_spec
+         self.gather_dim = gather_dim
+         self.shard_dim = shard_dim
+         self.logical_process_axis = logical_process_axis
+         self.device_mesh = self.sharding_spec.device_mesh
+         self.mix_gather = mix_gather
+         self.forward_only = forward_only
+         if self.gather_dim:
+             assert len(self.gather_dim) == len(self.logical_process_axis), \
+                 f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}'
+         if self.shard_dim:
+             assert len(self.shard_dim) == len(self.logical_process_axis), \
+                 f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}'
+         if self.gather_dim and self.shard_dim:
+             assert len(self.shard_dim) == len(self.gather_dim), \
+                 f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}'
+
+     def get_comm_cost(self):
+         '''
+         For all_gather, all2all, and all_reduce operations, the alpha-beta cost model
+         provided by DeviceMesh is used to estimate the communication cost.
+         A shard operation is an on-chip slice, so its communication cost is zero.
+         '''
+         comm_size = self.sharding_spec.get_sharded_size_per_device()
+         dtype = self.sharding_spec.dtype
+
+         # reduce list_of_list to list
+         comm_dims = sum(self.logical_process_axis, [])
+         comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern,
+                                                         comm_dims, comm_size,
+                                                         dtype)
+         return comm_cost
+
+     def get_mem_cost(self):
+         return self.device_mesh.shape_consistency_manager.mem_cost([self])
+
+     def get_max_mem_cost(self):
+         return self.device_mesh.shape_consistency_manager.mem_cost(
+             [self], mem_pattern='max')
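estimate_comm_cost itself lives in DeviceMesh and is not part of this hunk. For orientation only, a generic alpha-beta sketch of what such a model typically looks like for a ring all-reduce; the constants and formula here are placeholders, not values used by TensorRT-LLM:

    def ring_all_reduce_cost(size_bytes, num_ranks, alpha_us=10.0, beta_us_per_mb=8.0):
        # latency term per step plus bandwidth term on the per-rank chunk, in microseconds
        steps = 2 * (num_ranks - 1)
        chunk_mb = (size_bytes / num_ranks) / (1024 * 1024)
        return steps * (alpha_us + beta_us_per_mb * chunk_mb)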
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py ADDED
@@ -0,0 +1,56 @@
+ import copy
+
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class Concatenation(Node):
+
+     def __init__(self, layer):
+         super().__init__(layer)
+         layer.to_subclass()
+         batch_dims = [i for i in range(len(self.get_output(0).shape))]
+         self.axis = layer.as_trt().axis
+         batch_dims.remove(self.axis)
+         self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
+         layer.to_base_class()
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+         dim_partition_dict_mapping = {}
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             if self.axis in dim_partition_dict:
+                 dim_partition_dict.pop(self.axis)
+             for idx in range(self.num_inputs):
+                 in_partition_dict = copy.deepcopy(dim_partition_dict)
+                 dim_partition_dict_mapping[f'input{idx}'] = in_partition_dict
+             out_partition_dict = dim_partition_dict
+             dim_partition_dict_mapping['output0'] = out_partition_dict
+
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = <concate along dim {}> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence, self.axis, [
+                     sharding_spec_mapping[f'input{idx}'].sharding_sequence
+                     for idx in range(self.num_inputs)
+                 ])
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
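Because concatenated slices live at different offsets along self.axis, the loop above drops that axis from every candidate partition, so no strategy ever shards the concatenation dimension. A tiny illustration with made-up values:

    candidate = {0: [0], 1: [1]}   # dim 0 on mesh axis 0, concat axis 1 on mesh axis 1
    axis = 1
    candidate.pop(axis, None)      # -> {0: [0]}: the concatenation axis stays replicated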
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py ADDED
@@ -0,0 +1,45 @@
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class Constant(Node):
+
+     def _update_memory_cost(self, strategies):
+         super()._update_memory_cost(strategies)
+         for strategy in strategies:
+             strategy.inout_memory_footprint = 0.0
+             strategy.peak_memory_footprint = 0.0
+             strategy.const_memory_footprint = strategy.sharding_specs[
+                 'output0'].get_max_sharded_size_per_device()
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             dim_partition_dict_mapping = {'output0': dim_partition_dict}
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
+             sharding_strategy = self._get_sharding_strategy(
+                 name=f'constant-op {sharding_seq}',
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+
+         return strategies_vector
+
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py ADDED
@@ -0,0 +1,49 @@
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class ElementWise(Node):
+
+     def __init__(self, layer):
+         super().__init__(layer)
+         batch_dims = [i for i in range(len(self.get_output(0).shape))]
+         self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             in0_partition_dict = self._recover_bcast_partition_dict(
+                 dim_partition_dict, self.op_data['input0'])
+             in1_partition_dict = self._recover_bcast_partition_dict(
+                 dim_partition_dict, self.op_data['input1'])
+             out_partition_dict = dim_partition_dict
+             dim_partition_dict_mapping = {
+                 "input0": in0_partition_dict,
+                 "input1": in1_partition_dict,
+                 "output0": out_partition_dict,
+             }
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = {} <elementwise> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 sharding_spec_mapping['input0'].sharding_sequence,
+                 sharding_spec_mapping['input1'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py ADDED
@@ -0,0 +1,59 @@
+ import tensorrt as trt
+
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class Fill(Node):
+
+     def __init__(self, layer):
+         super().__init__(layer)
+         layer.to_subclass()
+         self.operation = layer.as_trt().operation
+         layer.to_base_class()
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE:
+             dim_partition_list.extend(
+                 self._enumerate_all_possible_1d_sharding([0], dim_size))
+             dim_partition_list.extend(
+                 self._enumerate_all_possible_1d_sharding([1], dim_size))
+             dim_partition_list.extend(
+                 self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+             dim_partition_list.extend(
+                 self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             dim_partition_dict_mapping = {'output0': dim_partition_dict}
+             for i in range(self.num_inputs):
+                 dim_partition_dict_mapping[f'input{i}'] = {}
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
+             sharding_strategy = self._get_sharding_strategy(
+                 name=f'fill-op {sharding_seq}',
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+
+         return strategies_vector
+
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         updated_layer_attrs = {}
+         updated_input_values = {}
+         shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device()
+         if self.layer.num_inputs >= 1:
+             updated_input_values[0] = shape
+         else:
+             updated_layer_attrs['shape'] = shape
+         elapsed_time = self.node_runtime_profiler.runtime_profile(
+             self.layer, updated_layer_attrs, updated_input_values, strategy,
+             device_mesh)
+         return elapsed_time
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py ADDED
@@ -0,0 +1,196 @@
+ import copy
+
+ import tensorrt as trt
+
+ from .comm_spec import CommSpec
+ from .node import Node
+ from .sharding_spec import DimSpec
+ from .sharding_strategy import StrategiesVector
+
+
+ class Gather(Node):
+
+     def __init__(self, layer):
+         super().__init__(layer)
+         layer.to_subclass()
+         self.mode = layer.as_trt().mode
+         self.axis = layer.as_trt().axis
+         self.num_elementwise_dims = layer.as_trt().num_elementwise_dims
+         self.input_id = 0
+         self.indice_id = 1
+         self.support_vocab_tp = False
+         layer.to_base_class()
+
+     def _update_memory_cost(self, strategies):
+         for strategy in strategies:
+             # for the gather node, input0's read volume equals output0's write volume
+             inout_memory_footprint = (
+                 strategy.sharding_specs['output0'].get_sharded_size_per_device() * 2 +
+                 strategy.sharding_specs['input1'].get_sharded_size_per_device())
+             strategy.inout_memory_footprint = inout_memory_footprint
+             strategy.peak_memory_footprint = (
+                 strategy.sharding_specs['output0'].get_max_sharded_size_per_device() +
+                 strategy.sharding_specs['input0'].get_max_sharded_size_per_device() +
+                 strategy.sharding_specs['input1'].get_max_sharded_size_per_device())
+
+     def _collect_strategies(self, device_mesh):
+         if self.mode == trt.GatherMode.DEFAULT:
+             return self._default_gather_strategies(device_mesh)
+         elif self.mode == trt.GatherMode.ELEMENT:
+             return self._element_gather_strategies(device_mesh)
+         elif self.mode == trt.GatherMode.ND:
+             assert 0, 'unsupported gatherND'
+         else:
+             assert 0, f'unsupported gather mode {self.mode}'
+
+     def _element_gather_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             if self.axis in dim_partition_dict:
+                 dim_partition_dict.pop(self.axis)
+
+             dim_partition_dict_mapping = {
+                 'input0': dim_partition_dict,
+                 'input1': copy.deepcopy(dim_partition_dict),
+                 'output0': copy.deepcopy(dim_partition_dict),
+             }
+
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = {} <element gather op axis {}> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 sharding_spec_mapping['input0'].sharding_sequence, self.axis,
+                 sharding_spec_mapping['input1'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+
+         return strategies_vector
+
+     # for the plugin, indices are input0 and the weight is input1, which differs from the gather node
+     def _default_gather_strategies(self, device_mesh):
+
+         def add_sharding_strategy(dim_partition_dict_mapping,
+                                   vocab_tp_dim=None):
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if len(sharding_spec_mapping) > 0:
+                 name = '{} = {} <default gather op axis {}, num_elementwise_dims {}> {}'.format(
+                     sharding_spec_mapping['output0'].sharding_sequence,
+                     sharding_spec_mapping['input0'].sharding_sequence,
+                     self.axis, self.num_elementwise_dims,
+                     sharding_spec_mapping['input1'].sharding_sequence)
+                 communication_action_mapping = {}
+                 if vocab_tp_dim is not None:
+                     name += f'_allreduce{DimSpec(vocab_tp_dim)}'
+                     output0_comm_action = CommSpec(
+                         comm_pattern='all_reduce',
+                         sharding_spec=sharding_spec_mapping['output0'],
+                         logical_process_axis=[vocab_tp_dim],
+                     )
+                     communication_action_mapping['output0'] = output0_comm_action
+                 sharding_strategy = self._get_sharding_strategy(
+                     name=name,
+                     sharding_spec_mapping=sharding_spec_mapping,
+                     communication_action_mapping=communication_action_mapping)
+                 strategies_vector.append(sharding_strategy)
+
+         input_id, indice_id = self.input_id, self.indice_id
+         strategies_vector = StrategiesVector(self)
+         input_size = len(self.op_data[f'input{input_id}'].shape)
+         indice_size = len(self.op_data[f'input{indice_id}'].shape)
+         output_dim = input_size + indice_size - 1 - self.num_elementwise_dims
+         for strategy in self.predecessor_nodes[input_id].strategies_vector:
+             # current node's local name input0 -> global name xxx
+             global_input_name = self.op_data[f'input{input_id}'].name
+             # global name xxx -> pre node local output name
+             prenode_local_name = self.predecessor_nodes[
+                 input_id].global_to_local_op_name[global_input_name]
+             input_dim_partition_dict = copy.deepcopy(
+                 strategy.sharding_specs[prenode_local_name].dim_partition_dict)
+
+             vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None)
+
+             input_mesh_dims = []
+             for dim, mesh_dims in input_dim_partition_dict.items():
+                 input_mesh_dims += mesh_dims
+             input_mesh_dims = set(input_mesh_dims)
+
+             for idx_strategy in self.predecessor_nodes[indice_id].strategies_vector:
+                 # current node's local name input1 -> global name xxx
+                 global_indice_name = self.op_data[f'input{indice_id}'].name
+                 # global name xxx -> pre node local output name
+                 prenode_local_name = self.predecessor_nodes[
+                     indice_id].global_to_local_op_name[global_indice_name]
+                 indice_dim_partition_dict = copy.deepcopy(
+                     idx_strategy.sharding_specs[prenode_local_name].dim_partition_dict)
+
+                 for dim, indice_mesh_dims in idx_strategy.sharding_specs[
+                         prenode_local_name].dim_partition_dict.items():
+                     for indice_mesh_dim in indice_mesh_dims:
+                         if indice_mesh_dim in input_mesh_dims:
+                             indice_dim_partition_dict.pop(dim)
+                             break
+
+                 out_partition_dict = {}
+
+                 for dim in range(output_dim):
+                     if dim < self.axis:
+                         if dim in input_dim_partition_dict:
+                             out_partition_dict[dim] = \
+                                 input_dim_partition_dict[dim]
+                     elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims:
+                         indice_dim = dim - self.axis + self.num_elementwise_dims
+                         if indice_dim in indice_dim_partition_dict:
+                             out_partition_dict[dim] = \
+                                 indice_dim_partition_dict[indice_dim]
+                     else:
+                         input_dim = dim - (indice_size -
+                                            self.num_elementwise_dims) + 1
+                         if input_dim in input_dim_partition_dict:
+                             out_partition_dict[dim] = \
+                                 input_dim_partition_dict[input_dim]
+
+                 dim_partition_dict_mapping = {
+                     f"input{input_id}": input_dim_partition_dict,
+                     f"input{indice_id}": indice_dim_partition_dict,
+                     "output0": out_partition_dict,
+                 }
+                 add_sharding_strategy(dim_partition_dict_mapping)
+
+                 if self.support_vocab_tp and vocab_tp_dim is not None:
+                     vocab_tp_dim_partition_dict = {
+                         **input_dim_partition_dict,
+                         self.axis: vocab_tp_dim,
+                     }
+                     dim_partition_dict_mapping = {
+                         f"input{input_id}": vocab_tp_dim_partition_dict,
+                         f"input{indice_id}": indice_dim_partition_dict,
+                         "output0": out_partition_dict,
+                     }
+                     add_sharding_strategy(dim_partition_dict_mapping,
+                                           vocab_tp_dim)
+
+         return strategies_vector
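For the default gather, the output rank is input_size + indice_size - 1 - num_elementwise_dims, and the dimension mapping in the loop above can be checked with a small worked example resembling an embedding lookup (shapes are illustrative, not taken from a real graph):

    input_size, indice_size = 2, 2        # weight [vocab, hidden], indices [batch, seq]
    axis, num_elementwise_dims = 0, 0
    output_dim = input_size + indice_size - 1 - num_elementwise_dims   # -> 3
    # output dims 0..1 come from the indices (batch, seq), dim 2 maps back to input dim 1 (hidden),
    # so a split of the weight's hidden dimension carries over to output dim 2.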
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py ADDED
@@ -0,0 +1,56 @@
+ import copy
+
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class Identity(Node):
+
+     def _update_memory_cost(self, strategies):
+         if not self.is_fake:
+             super()._update_memory_cost(strategies)
+         else:
+             # fake nodes for building block/PP connection
+             pass
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['input0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         strategies_vector = StrategiesVector(self)
+         # dim_partition_dict can be reused from the previous node if solver time becomes a problem
+         for dim_partition_dict in dim_partition_list:
+             in0_partition_dict = dim_partition_dict
+             out_partition_dict = copy.deepcopy(dim_partition_dict)
+             dim_partition_dict_mapping = {
+                 "input0": in0_partition_dict,
+                 "output0": out_partition_dict,
+             }
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = <identity op> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 sharding_spec_mapping['input0'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
+
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         # if same_spec_id is set (not -1), this identity node only acts as a same-spec
+         # placeholder, so it has no profiling cost
+         if self.same_spec_id == -1:
+             return super()._profile_sharding_cost(strategy, device_mesh)
+         else:
+             return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py ADDED
@@ -0,0 +1,79 @@
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+
+
+ class InputNode(Node):
+
+     def _update_memory_cost(self, strategies):
+         for strategy in strategies:
+             if not self.no_memory_footprint:
+                 strategy.const_memory_footprint = strategy.sharding_specs[
+                     'output0'].get_max_sharded_size_per_device()
+
+     def __init__(self, tensor):
+         self._layer = None
+         self.is_shape_io = False
+         self._inputs = []
+         self._outputs = []
+         self.predecessor_nodes = []
+         self.predecessor_nodes_out_index = {}
+         self.successor_nodes = []
+         self.op_data = {}
+         self.global_to_local_op_name = {}
+         self.is_replicated = tensor.attrs.get("is_replicated", False)
+         self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
+         self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
+                                                     False)
+         self.building_block_id = -1
+         self.cost_level = -1
+         self.stage_type = None
+         self.in_start_block = None
+         self.in_end_block = None
+         self.in_slowest_block = None
+         output = tensor.copy()
+         self._outputs.append(output)
+         self.op_data['output0'] = output
+         self.global_to_local_op_name[output.name] = 'output0'
+
+         self.sharding_weight = 1.0
+         self.resharding_weight = 1.0
+         self.pipeline_weight = 0
+         self.node_name = tensor.name
+         self.node_type = 'input_node'
+         self.num_inputs = 0
+         self.num_outputs = 1
+         self.dtype = tensor.dtype
+         self.strategies_vector = []
+         self.node_runtime_profiler = None
+
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['output0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+
+         strategies_vector = StrategiesVector(self)
+         for dim_partition_dict in dim_partition_list:
+             dim_partition_dict_mapping = {'output0': dim_partition_dict}
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
+             sharding_strategy = self._get_sharding_strategy(
+                 name=f'input-op {sharding_seq}',
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+
+         return strategies_vector
+
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py ADDED
@@ -0,0 +1,798 @@
1
+ import copy
2
+ import operator
3
+ from functools import reduce
4
+
5
+ import tensorrt as trt
6
+
7
+ from ..device_mesh import LogicalDeviceMesh
8
+ from ..utils import get_builder_flags
9
+ from .comm_spec import CommSpec
10
+ from .node import Node
11
+ from .sharding_spec import DimSpec
12
+ from .sharding_strategy import StrategiesVector
13
+
14
+
15
+ class MatrixMultiply(Node):
16
+
17
+ def __init__(self, layer):
18
+ super().__init__(layer)
19
+ layer.to_subclass()
20
+ batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
21
+ self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
22
+ self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE
23
+ self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE
24
+ self.num_out_dims = len(self.get_output(0).shape)
25
+ dtypes_str = [
26
+ self.get_input(0).dtype_str,
27
+ self.get_input(1).dtype_str,
28
+ self.get_output(0).dtype_str
29
+ ]
30
+ dtypes_size = [
31
+ self.get_input(0).dtype_size,
32
+ self.get_input(1).dtype_size,
33
+ self.get_output(0).dtype_size
34
+ ]
35
+ min_idx = dtypes_size.index(min(dtypes_size))
36
+ self.dtype = dtypes_str[min_idx]
37
+ layer.to_base_class()
38
+
39
+ def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh):
40
+ in0_split_dim = -1 if self.op0_transpose else -2
41
+ in1_split_dim = -2 if self.op1_transpose else -1
42
+ name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = '
43
+ f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}')
44
+ dim_partition_dict_mapping = {
45
+ "input0": {
46
+ in0_split_dim: mesh_dim_0
47
+ },
48
+ "input1": {
49
+ in1_split_dim: mesh_dim_1
50
+ },
51
+ "output0": {
52
+ -2: mesh_dim_0,
53
+ -1: mesh_dim_1
54
+ },
55
+ }
56
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
57
+ dim_partition_dict_mapping, device_mesh)
58
+ if len(sharding_spec_mapping) == 0:
59
+ return None
60
+ strategy = self._get_sharding_strategy(name = name, \
61
+ sharding_spec_mapping = sharding_spec_mapping, \
62
+ communication_action_mapping = {})
63
+ return strategy
64
+
65
+ def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
66
+ device_mesh):
67
+ # handle the case SR = SS x SR
68
+ name = (
69
+ f'{DimSpec(mesh_dim_0)}R = '
70
+ f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R'
71
+ f'_allreduce{DimSpec(mesh_dim_1)}')
72
+ in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
73
+ in1_split_dim = -1 if self.op1_transpose else -2
74
+ # get sharding spec mapping
75
+ dim_partition_dict_mapping = {
76
+ "input0": {
77
+ in0_split_dim[0]: mesh_dim_0,
78
+ in0_split_dim[1]: mesh_dim_1
79
+ },
80
+ "input1": {
81
+ in1_split_dim: mesh_dim_1
82
+ },
83
+ "output0": {
84
+ -2: mesh_dim_0
85
+ },
86
+ }
87
+
88
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
89
+ dim_partition_dict_mapping, device_mesh)
90
+ if len(sharding_spec_mapping) == 0:
91
+ return None
92
+ # get communication action mapping
93
+ communication_action_mapping = {}
94
+ output0_comm_action = CommSpec(
95
+ comm_pattern='all_reduce',
96
+ sharding_spec=sharding_spec_mapping['output0'],
97
+ logical_process_axis=[mesh_dim_1],
98
+ )
99
+ communication_action_mapping['output0'] = output0_comm_action
100
+ return self._get_sharding_strategy(
101
+ name=name,
102
+ sharding_spec_mapping=sharding_spec_mapping,
103
+ communication_action_mapping=communication_action_mapping)
104
+
105
+ def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec,
106
+ dim_partition_dict_mapping, device_mesh):
107
+ output0_comm_action = CommSpec(
108
+ comm_pattern='reduce_scatter',
109
+ sharding_spec=src_spec,
110
+ shard_dim=[rs_dim],
111
+ logical_process_axis=[rs_mesh_dim],
112
+ )
113
+ rs_out_partition_dict_mapping = copy.deepcopy(
114
+ dim_partition_dict_mapping)
115
+ rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim
116
+ rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
117
+ rs_out_partition_dict_mapping, device_mesh)
118
+ if len(rs_out_sharding_spec_mapping) == 0:
119
+ return None
120
+
121
+ communication_action_mapping = {}
122
+ communication_action_mapping['output0'] = output0_comm_action
123
+ return self._get_sharding_strategy(
124
+ name=name,
125
+ sharding_spec_mapping=rs_out_sharding_spec_mapping,
126
+ communication_action_mapping=communication_action_mapping)
127
+
128
+ def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
129
+ device_mesh):
130
+ # handle the case SS = SS x SR -> reduce_scatter
131
+ in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
132
+ in1_split_dim = -1 if self.op1_transpose else -2
133
+ # get sharding spec mapping
134
+ dim_partition_dict_mapping = {
135
+ "input0": {
136
+ in0_split_dim[0]: mesh_dim_0,
137
+ in0_split_dim[1]: mesh_dim_1
138
+ },
139
+ "input1": {
140
+ in1_split_dim: mesh_dim_1
141
+ },
142
+ "output0": {
143
+ -2: mesh_dim_0,
144
+ },
145
+ }
146
+ mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
147
+ dim_partition_dict_mapping, device_mesh)
148
+ if len(mm_out_sharding_spec_mapping) == 0:
149
+ return []
150
+ strategies = []
151
+ for rs_dim in range(self.num_out_dims):
152
+ if rs_dim != self.num_out_dims - 2:
153
+ name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
154
+ 'R'
155
+ ] * self.num_out_dims, ['R'] * self.num_out_dims
156
+ name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str(
157
+ DimSpec(mesh_dim_1))
158
+ name_in1[-2] = str(DimSpec(mesh_dim_1))
159
+ name_out0[-2], name_out0[rs_dim] = str(
160
+ DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1))
161
+ name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
162
+ name_in1), ', '.join(name_out0)
163
+ name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
164
+ f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}')
165
+ ret = self._split_both_contract_rs(
166
+ name, rs_dim, mesh_dim_1,
167
+ mm_out_sharding_spec_mapping['output0'],
168
+ dim_partition_dict_mapping, device_mesh)
169
+ if ret:
170
+ strategies.append(ret)
171
+ return strategies
172
+
173
+ def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
174
+ device_mesh):
175
+ name = (
176
+ f'R{DimSpec(mesh_dim_1)} = '
177
+ f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}'
178
+ f'_allreduce{DimSpec(mesh_dim_0)}')
179
+ in0_split_dim = -2 if self.op0_transpose else -1
180
+ in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
181
+ # get sharding specs
182
+ dim_partition_dict_mapping = {
183
+ "input0": {
184
+ in0_split_dim: mesh_dim_0
185
+ },
186
+ "input1": {
187
+ in1_split_dim[0]: mesh_dim_0,
188
+ in1_split_dim[1]: mesh_dim_1
189
+ },
190
+ "output0": {
191
+ -1: mesh_dim_1
192
+ },
193
+ }
194
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
195
+ dim_partition_dict_mapping, device_mesh)
196
+ if len(sharding_spec_mapping) == 0:
197
+ return None
198
+ # get communication actions
199
+ communication_action_mapping = {}
200
+ output0_comm_action = CommSpec(
201
+ comm_pattern='all_reduce',
202
+ sharding_spec=sharding_spec_mapping['output0'],
203
+ logical_process_axis=[mesh_dim_0],
204
+ )
205
+ communication_action_mapping['output0'] = output0_comm_action
206
+ return self._get_sharding_strategy(
207
+ name=name,
208
+ sharding_spec_mapping=sharding_spec_mapping,
209
+ communication_action_mapping=communication_action_mapping)
210
+
211
+ def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
212
+ device_mesh):
213
+ in0_split_dim = -2 if self.op0_transpose else -1
214
+ in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
215
+ # get sharding specs
216
+ dim_partition_dict_mapping = {
217
+ "input0": {
218
+ in0_split_dim: mesh_dim_0
219
+ },
220
+ "input1": {
221
+ in1_split_dim[0]: mesh_dim_0,
222
+ in1_split_dim[1]: mesh_dim_1
223
+ },
224
+ "output0": {
225
+ -1: mesh_dim_1
226
+ },
227
+ }
228
+ mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
229
+ dim_partition_dict_mapping, device_mesh)
230
+ if len(mm_out_sharding_spec_mapping) == 0:
231
+ return []
232
+ strategies = []
233
+ for rs_dim in range(self.num_out_dims):
234
+ if rs_dim != self.num_out_dims - 1:
235
+ name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
236
+ 'R'
237
+ ] * self.num_out_dims, ['R'] * self.num_out_dims
238
+ name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str(
239
+ DimSpec(mesh_dim_1))
240
+ name_in0[-1] = str(DimSpec(mesh_dim_0))
241
+ name_out0[-1], name_out0[rs_dim] = str(
242
+ DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0))
243
+ name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
244
+ name_in1), ', '.join(name_out0)
245
+ name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
246
+ f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}')
247
+ ret = self._split_both_contract_rs(
248
+ name, rs_dim, mesh_dim_0,
249
+ mm_out_sharding_spec_mapping['output0'],
250
+ dim_partition_dict_mapping, device_mesh)
251
+ if ret:
252
+ strategies.append(ret)
253
+ return strategies
254
+
255
+ def _recompute_split_both_contract(self, mesh_dim, device_mesh):
256
+ name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
257
+ f'_allreduce{DimSpec(mesh_dim)}')
258
+ in0_split_dim = -2 if self.op0_transpose else -1
259
+ in1_split_dim = -1 if self.op1_transpose else -2
260
+ dim_partition_dict_mapping = {
261
+ "input0": {
262
+ in0_split_dim: mesh_dim
263
+ },
264
+ "input1": {
265
+ in1_split_dim: mesh_dim
266
+ },
267
+ "output0": {},
268
+ }
269
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
270
+ dim_partition_dict_mapping, device_mesh)
271
+ if len(sharding_spec_mapping) == 0:
272
+ return None
273
+
274
+ # get communication action
275
+ communication_action_mapping = {}
276
+ output0_comm_action = CommSpec(
277
+ comm_pattern='all_reduce',
278
+ sharding_spec=sharding_spec_mapping['output0'],
279
+ logical_process_axis=[mesh_dim],
280
+ )
281
+ communication_action_mapping['output0'] = output0_comm_action
282
+ return self._get_sharding_strategy(
283
+ name=name,
284
+ sharding_spec_mapping=sharding_spec_mapping,
285
+ communication_action_mapping=communication_action_mapping)
286
+
287
+ def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh):
288
+ name = (f'{DimSpec(mesh_dim)}R = '
289
+ f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
290
+ f'_reducescatter0_{DimSpec(mesh_dim)}')
291
+ in0_split_dim = -2 if self.op0_transpose else -1
292
+ in1_split_dim = -1 if self.op1_transpose else -2
293
+ dim_partition_dict_mapping = {
294
+ "input0": {
295
+ in0_split_dim: mesh_dim
296
+ },
297
+ "input1": {
298
+ in1_split_dim: mesh_dim
299
+ },
300
+ "output0": {},
301
+ }
302
+ mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
303
+ dim_partition_dict_mapping, device_mesh)
304
+ if len(mm_out_sharding_spec_mapping) == 0:
305
+ return []
306
+
307
+ strategies = []
308
+ for rs_dim in range(self.num_out_dims):
309
+ name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
310
+ 'R'
311
+ ] * self.num_out_dims, ['R'] * self.num_out_dims
312
+ name_in0[-1], name_in1[-2], name_out0[rs_dim] = str(
313
+ DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str(
314
+ DimSpec(mesh_dim))
315
+ name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
316
+ name_in1), ', '.join(name_out0)
317
+ name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}'
318
+ ret = self._split_both_contract_rs(
319
+ name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'],
320
+ dim_partition_dict_mapping, device_mesh)
321
+ if ret:
322
+ strategies.append(ret)
323
+ return strategies
324
+
325
+ def _split_rhs_space_only(self, mesh_dim, device_mesh):
326
+ name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}'
327
+ in1_split_dim = -2 if self.op1_transpose else -1
328
+ # get sharding spec
329
+ dim_partition_dict_mapping = {
330
+ "input0": {},
331
+ "input1": {
332
+ in1_split_dim: mesh_dim
333
+ },
334
+ "output0": {
335
+ -1: mesh_dim
336
+ },
337
+ }
338
+ # We don't have to do anything special for bias here, because
339
+ # the bias is already the same sharding spec as the output0.
340
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
341
+ dim_partition_dict_mapping, device_mesh)
342
+ if len(sharding_spec_mapping) == 0:
343
+ return None
344
+ return self._get_sharding_strategy(
345
+ name=name,
346
+ sharding_spec_mapping=sharding_spec_mapping,
347
+ communication_action_mapping={})
348
+
349
+ def _split_lhs_space_only(self, mesh_dim, device_mesh):
350
+ name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR'
351
+ in0_split_dim = -1 if self.op0_transpose else -2
352
+ # get sharding spec
353
+ dim_partition_dict_mapping = {
354
+ "input0": {
355
+ in0_split_dim: mesh_dim
356
+ },
357
+ "input1": {},
358
+ "output0": {
359
+ -2: mesh_dim
360
+ },
361
+ }
362
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
363
+ dim_partition_dict_mapping, device_mesh)
364
+ if len(sharding_spec_mapping) == 0:
365
+ return None
366
+ return self._get_sharding_strategy(
367
+ name=name,
368
+ sharding_spec_mapping=sharding_spec_mapping,
369
+ communication_action_mapping={})
370
+
371
+ def _non_split(self, device_mesh):
372
+ name = 'RR = RR x RR'
373
+ # get sharding spec
374
+ dim_partition_dict_mapping = {
375
+ "input0": {},
376
+ "input1": {},
377
+ "output0": {},
378
+ }
379
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
380
+ dim_partition_dict_mapping, device_mesh)
381
+ if len(sharding_spec_mapping) == 0:
382
+ return None
383
+ return self._get_sharding_strategy(
384
+ name=name,
385
+ sharding_spec_mapping=sharding_spec_mapping,
386
+ communication_action_mapping={})
387
+
388
+ def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh):
389
+ name = (
390
+ f'{DimSpec(mesh_dim)}b{batch_dim}RR = '
391
+ f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR'
392
+ )
393
+ in0_data = self.op_data['input0']
394
+ in1_data = self.op_data['input1']
395
+
396
+ batch_partition_dict = {batch_dim: mesh_dim}
397
+ in0_parition_dict = self._recover_bcast_partition_dict(
398
+ batch_partition_dict, in0_data)
399
+ in1_parition_dict = self._recover_bcast_partition_dict(
400
+ batch_partition_dict, in1_data)
401
+ out_partition_dict = {batch_dim: mesh_dim}
402
+ # TODO:[KDuan] Double check if MatrixMultiplication's output has bcast in dim
403
+ dim_partition_dict_mapping = {
404
+ "input0": in0_parition_dict,
405
+ "input1": in1_parition_dict,
406
+ "output0": out_partition_dict,
407
+ }
408
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
409
+ dim_partition_dict_mapping, device_mesh)
410
+ if len(sharding_spec_mapping) == 0:
411
+ return None
412
+ return self._get_sharding_strategy(
413
+ name=name,
414
+ sharding_spec_mapping=sharding_spec_mapping,
415
+ communication_action_mapping={})
416
+
417
+ def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0,
418
+ mesh_dim1, device_mesh):
419
+ name = (
420
+ f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = '
421
+ f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR'
422
+ )
423
+ in0_data = self.op_data['input0']
424
+ in1_data = self.op_data['input1']
425
+
426
+ in0_parition_dict = {}
427
+ if batch_dim0 not in in0_data.attrs["broadcast_dims"]:
428
+ in0_parition_dict[batch_dim0] = mesh_dim0
429
+ if batch_dim1 not in in0_data.attrs["broadcast_dims"]:
430
+ in0_parition_dict[batch_dim1] = mesh_dim1
431
+
432
+ in1_parition_dict = {}
433
+ if batch_dim0 not in in1_data.attrs["broadcast_dims"]:
434
+ in1_parition_dict[batch_dim0] = mesh_dim0
435
+ if batch_dim1 not in in1_data.attrs["broadcast_dims"]:
436
+ in1_parition_dict[batch_dim1] = mesh_dim1
437
+
438
+ batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
439
+ in0_parition_dict = self._recover_bcast_partition_dict(
440
+ batch_partition_dict, in0_data)
441
+ in1_parition_dict = self._recover_bcast_partition_dict(
442
+ batch_partition_dict, in1_data)
443
+ out_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
444
+ dim_partition_dict_mapping = {
445
+ "input0": in0_parition_dict,
446
+ "input1": in1_parition_dict,
447
+ "output0": out_partition_dict,
448
+ }
449
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
450
+ dim_partition_dict_mapping, device_mesh)
451
+ if len(sharding_spec_mapping) == 0:
452
+ return None
453
+ return self._get_sharding_strategy(
454
+ name=name,
455
+ sharding_spec_mapping=sharding_spec_mapping,
456
+ communication_action_mapping={})
457
+
458
+ def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
459
+ device_mesh):
460
+
461
+ name = (
462
+ f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = '
463
+ f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR'
464
+ )
465
+ in0_data = self.op_data['input0']
466
+ in1_data = self.op_data['input1']
467
+ in0_parition_dict = {batch_dim: mesh_dim0}
468
+ in1_parition_dict = {batch_dim: mesh_dim0}
469
+ in0_lhs_split_dim = -1 if self.op0_transpose else -2
470
+ in0_parition_dict[in0_lhs_split_dim] = mesh_dim1
471
+
472
+ in0_parition_dict = self._recover_bcast_partition_dict(
473
+ in0_parition_dict, in0_data)
474
+ in1_parition_dict = self._recover_bcast_partition_dict(
475
+ in1_parition_dict, in1_data)
476
+ out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1}
477
+
478
+ dim_partition_dict_mapping = {
479
+ "input0": in0_parition_dict,
480
+ "input1": in1_parition_dict,
481
+ "output0": out_partition_dict,
482
+ }
483
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
484
+ dim_partition_dict_mapping, device_mesh)
485
+ if len(sharding_spec_mapping) == 0:
486
+ return None
487
+ return self._get_sharding_strategy(
488
+ name=name,
489
+ sharding_spec_mapping=sharding_spec_mapping,
490
+ communication_action_mapping={})
491
+
492
+ def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
493
+ device_mesh):
494
+
495
+ name = (
496
+ f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = '
497
+ f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}'
498
+ )
499
+ in0_data = self.op_data['input0']
500
+ in1_data = self.op_data['input1']
501
+ in0_parition_dict = {batch_dim: mesh_dim0}
502
+ in1_parition_dict = {batch_dim: mesh_dim0}
503
+
504
+ in1_rhs_split_dim = -2 if self.op1_transpose else -1
505
+ in1_parition_dict[in1_rhs_split_dim] = mesh_dim1
506
+
507
+ in0_parition_dict = self._recover_bcast_partition_dict(
508
+ in0_parition_dict, in0_data)
509
+ in1_parition_dict = self._recover_bcast_partition_dict(
510
+ in1_parition_dict, in1_data)
511
+ out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1}
512
+ dim_partition_dict_mapping = {
513
+ "input0": in0_parition_dict,
514
+ "input1": in1_parition_dict,
515
+ "output0": out_partition_dict,
516
+ }
517
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
518
+ dim_partition_dict_mapping, device_mesh)
519
+ if len(sharding_spec_mapping) == 0:
520
+ return None
521
+ return self._get_sharding_strategy(
522
+ name=name,
523
+ sharding_spec_mapping=sharding_spec_mapping,
524
+ communication_action_mapping={})
525
+
526
+ def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1,
527
+ device_mesh):
528
+
529
+ name = (
530
+ f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
531
+ f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
532
+ f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
533
+ )
534
+ in0_data = self.op_data['input0']
535
+ in1_data = self.op_data['input1']
536
+ in0_parition_dict = {batch_dim: mesh_dim0}
537
+ in1_parition_dict = {batch_dim: mesh_dim0}
538
+
539
+ in0_contract_dim = -2 if self.op0_transpose else -1
540
+ in1_contract_dim = -1 if self.op1_transpose else -2
541
+ in0_parition_dict[in0_contract_dim] = mesh_dim1
542
+ in1_parition_dict[in1_contract_dim] = mesh_dim1
543
+
544
+ in0_parition_dict = self._recover_bcast_partition_dict(
545
+ in0_parition_dict, in0_data)
546
+ in1_parition_dict = self._recover_bcast_partition_dict(
547
+ in1_parition_dict, in1_data)
548
+ out_partition_dict = {batch_dim: mesh_dim0}
549
+ dim_partition_dict_mapping = {
550
+ "input0": in0_parition_dict,
551
+ "input1": in1_parition_dict,
552
+ "output0": out_partition_dict,
553
+ }
554
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
555
+ dim_partition_dict_mapping, device_mesh)
556
+ if len(sharding_spec_mapping) == 0:
557
+ return None
558
+
559
+ # get communication actions
560
+ communication_action_mapping = {}
561
+ output0_comm_action = CommSpec(
562
+ comm_pattern='all_reduce',
563
+ sharding_spec=sharding_spec_mapping['output0'],
564
+ logical_process_axis=[mesh_dim1],
565
+ )
566
+ communication_action_mapping['output0'] = output0_comm_action
567
+ return self._get_sharding_strategy(
568
+ name=name,
569
+ sharding_spec_mapping=sharding_spec_mapping,
570
+ communication_action_mapping=communication_action_mapping)
571
+
572
+ def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1,
573
+ device_mesh):
574
+
575
+ name = (
576
+ f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
577
+ f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
578
+ f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
579
+ )
580
+ in0_data = self.op_data['input0']
581
+ in1_data = self.op_data['input1']
582
+ in0_parition_dict = {batch_dim: mesh_dim0}
583
+ in1_parition_dict = {batch_dim: mesh_dim0}
584
+
585
+ in0_contract_dim = -2 if self.op0_transpose else -1
586
+ in1_contract_dim = -1 if self.op1_transpose else -2
587
+ in0_parition_dict[in0_contract_dim] = mesh_dim1
588
+ in1_parition_dict[in1_contract_dim] = mesh_dim1
589
+
590
+ in0_parition_dict = self._recover_bcast_partition_dict(
591
+ in0_parition_dict, in0_data)
592
+ in1_parition_dict = self._recover_bcast_partition_dict(
593
+ in1_parition_dict, in1_data)
594
+ out_partition_dict = {batch_dim: mesh_dim0}
595
+ dim_partition_dict_mapping = {
596
+ "input0": in0_parition_dict,
597
+ "input1": in1_parition_dict,
598
+ "output0": out_partition_dict,
599
+ }
600
+ mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
601
+ dim_partition_dict_mapping, device_mesh)
602
+ if len(mm_out_sharding_spec_mapping) == 0:
603
+ return []
604
+
605
+ strategies = []
606
+ for rs_dim in range(self.num_out_dims):
607
+ if rs_dim != batch_dim:
608
+ name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
609
+ 'R'
610
+ ] * self.num_out_dims, ['R'] * self.num_out_dims
611
+ name_in0[batch_dim], name_in0[-1] = str(
612
+ DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
613
+ name_in1[batch_dim], name_in1[-2] = str(
614
+ DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
615
+ name_in1[batch_dim], name_out0[rs_dim] = str(
616
+ DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
617
+ name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
618
+ name_in1), ', '.join(name_out0)
619
+ name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}'
620
+ ret = self._split_both_contract_rs(
621
+ name, rs_dim, mesh_dim1,
622
+ mm_out_sharding_spec_mapping['output0'],
623
+ dim_partition_dict_mapping, device_mesh)
624
+ if ret:
625
+ strategies.append(ret)
626
+ return strategies
627
+
628
+ def _dp_strategies(self, device_mesh):
629
+ strategies = []
630
+ # S0R = S0R x RR
631
+ strategies.append(self._split_lhs_space_only([0], device_mesh))
632
+ # S1R = S1R x RR
633
+ strategies.append(self._split_lhs_space_only([1], device_mesh))
634
+ # S01R = S01R x RR
635
+ strategies.append(self._split_lhs_space_only([0, 1], device_mesh))
636
+ return strategies
637
+
638
+ def _tp_strategies(self, device_mesh: LogicalDeviceMesh):
639
+ strategies = []
640
+ # RR = RS x SR _ AR
641
+ strategies.append(self._recompute_split_both_contract([0], device_mesh))
642
+ strategies.append(self._recompute_split_both_contract([1], device_mesh))
643
+ strategies.append(
644
+ self._recompute_split_both_contract([0, 1], device_mesh))
645
+
646
+ if device_mesh.config.enable_reduce_scatter:
647
+ # RS x SR _ reduce scatter
648
+ strategies.extend(
649
+ self._recompute_split_both_contract_rs([0], device_mesh))
650
+ strategies.extend(
651
+ self._recompute_split_both_contract_rs([1], device_mesh))
652
+ strategies.extend(
653
+ self._recompute_split_both_contract_rs([0, 1], device_mesh))
654
+
655
+ # RS = RR x RS
656
+ strategies.append(self._split_rhs_space_only([0], device_mesh))
657
+ strategies.append(self._split_rhs_space_only([1], device_mesh))
658
+ strategies.append(self._split_rhs_space_only([0, 1], device_mesh))
659
+
660
+ # RS = RS x SS _ AR
661
+ strategies.append(
662
+ self._split_rhs_space_both_contract([0], [1], device_mesh))
663
+ strategies.append(
664
+ self._split_rhs_space_both_contract([1], [0], device_mesh))
665
+
666
+ if device_mesh.config.enable_reduce_scatter:
667
+ # RS x SS _ reduce scatter
668
+ strategies.extend(
669
+ self._split_rhs_space_both_contract_rs([0], [1], device_mesh))
670
+ strategies.extend(
671
+ self._split_rhs_space_both_contract_rs([1], [0], device_mesh))
672
+
673
+ return strategies
674
+
675
+ def _mix_strategies(self, device_mesh):
676
+ strategies = []
677
+
678
+ # SR = SS x SR_AR
679
+ strategies.append(
680
+ self._split_lhs_space_both_contract([0], [1], device_mesh))
681
+ strategies.append(
682
+ self._split_lhs_space_both_contract([1], [0], device_mesh))
683
+ if device_mesh.config.enable_reduce_scatter:
684
+ # RS x SS _ reduce scatter
685
+ strategies.extend(
686
+ self._split_lhs_space_both_contract_rs([0], [1], device_mesh))
687
+ strategies.extend(
688
+ self._split_lhs_space_both_contract_rs([1], [0], device_mesh))
689
+ # SS = SR x RS
690
+ strategies.append(self._split_lhs_space_rhs_space([0], [1],
691
+ device_mesh))
692
+ strategies.append(self._split_lhs_space_rhs_space([1], [0],
693
+ device_mesh))
694
+
695
+ # RR = RR x RR
696
+ strategies.append(self._non_split(device_mesh))
697
+ return strategies
698
+
699
+ def _bmm_strategies(self, device_mesh: LogicalDeviceMesh):
700
+ strategies = []
701
+ bmm_dim = len(self.op_data['output0'].shape)
702
+ if bmm_dim >= 3:
703
+ for batch_dim in range(0, bmm_dim - 2):
704
+ strategies.append(
705
+ self._split_one_batch_dim(batch_dim, [0], device_mesh))
706
+ strategies.append(
707
+ self._split_one_batch_dim(batch_dim, [1], device_mesh))
708
+ strategies.append(
709
+ self._split_one_batch_dim(batch_dim, [0, 1], device_mesh))
710
+
711
+ strategies.append(
712
+ self._split_batch_dim_lhs_space(batch_dim, [0], [1],
713
+ device_mesh))
714
+ strategies.append(
715
+ self._split_batch_dim_lhs_space(batch_dim, [1], [0],
716
+ device_mesh))
717
+
718
+ strategies.append(
719
+ self._split_batch_dim_rhs_space(batch_dim, [0], [1],
720
+ device_mesh))
721
+ strategies.append(
722
+ self._split_batch_dim_rhs_space(batch_dim, [1], [0],
723
+ device_mesh))
724
+
725
+ strategies.append(
726
+ self._split_batch_dim_both_contract(batch_dim, [0], [1],
727
+ device_mesh))
728
+ strategies.append(
729
+ self._split_batch_dim_both_contract(batch_dim, [1], [0],
730
+ device_mesh))
731
+ if device_mesh.config.enable_reduce_scatter:
732
+ strategies.extend(
733
+ self._split_batch_dim_both_contract_rs(
734
+ batch_dim, [0], [1], device_mesh))
735
+ strategies.extend(
736
+ self._split_batch_dim_both_contract_rs(
737
+ batch_dim, [1], [0], device_mesh))
738
+ if bmm_dim >= 4:
739
+ for batch_dim0 in range(0, bmm_dim - 2):
740
+ for batch_dim1 in range(0, bmm_dim - 2):
741
+ if batch_dim0 != batch_dim1:
742
+ strategies.append(
743
+ self._split_two_batch_dims(
744
+ batch_dim0, batch_dim1, [0], [1],
745
+ device_mesh))
746
+
747
+ return strategies
748
+
749
+ def _collect_strategies(self, device_mesh):
750
+ strategies_vector = StrategiesVector(self)
751
+ dp_strategies = self._dp_strategies(device_mesh)
752
+ tp_strategies = self._tp_strategies(device_mesh)
753
+ mix_strategies = self._mix_strategies(device_mesh)
754
+ bmm_strategies = self._bmm_strategies(device_mesh)
755
+ strategies_vector.extend(dp_strategies)
756
+ strategies_vector.extend(tp_strategies)
757
+ strategies_vector.extend(mix_strategies)
758
+ strategies_vector.extend(bmm_strategies)
759
+ return strategies_vector
760
+
761
+ def is_fp16(self):
762
+ builder_flags = get_builder_flags()
763
+ return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0
764
+
765
+ def _get_math_time(self, strategy, device_mesh):
766
+ shape_in0 = strategy.sharding_specs[
767
+ 'input0'].get_sharded_shape_per_device()
768
+ shape_out = strategy.sharding_specs[
769
+ 'output0'].get_sharded_shape_per_device()
770
+ m, n = shape_out[-2], shape_out[-1]
771
+ batches = shape_out[:-2]
772
+ k = shape_in0[-2] if self.op0_transpose else shape_in0[-1]
773
+ macs_shape = batches + [m, n, k]
774
+ macs = reduce(operator.mul, macs_shape, 1) * 2
775
+ config = device_mesh.config
776
+ cluster_info = device_mesh.cluster_info
777
+ dtype = self.dtype
778
+ # For fp16 matmul ops that use_fp32_acc=True.
779
+ # They are mistaken for fp32 ops since all of their IO tensors use fp32 dtype.
780
+ if self.is_fp16() and self.dtype == "float32":
781
+ dtype = "float16"
782
+ math_throughput_tflops = getattr(cluster_info.math_throughput, dtype)
783
+ assert math_throughput_tflops != 0, \
784
+ "Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key)
785
+ math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency
786
+ return math_time
787
+
788
+ def _update_memory_cost(self, strategies):
789
+ super()._update_memory_cost(strategies)
790
+ # For fp16 matmul ops that use_fp32_acc=True.
791
+ # Their memory footprints are calculated based on fp32 IO tensors.
792
+ # Actually they will use fp16 IO tensors after fused.
793
+ # So we divide all the memory footprints by 2.
794
+ if self.is_fp16() and self.dtype == "float32":
795
+ for strategy in strategies:
796
+ strategy.inout_memory_footprint /= 2
797
+ strategy.peak_memory_footprint /= 2
798
+ strategy.comm_buff_memory_footprint /= 2
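_get_math_time above estimates compute time from the sharded shapes: FLOPs = 2 * (product of batch dims) * m * n * k, divided by the cluster's math throughput. A rough worked example, assuming math_throughput is given in TFLOPS so the result comes out in microseconds before the math_efficiency factor (the throughput number below is a placeholder, not a cluster value from this package):

    m = n = k = 4096
    flops = 2 * m * n * k                  # ~1.37e11 FLOPs for a single 4096^3 GEMM
    tflops = 1000.0                        # hypothetical fp16 throughput in TFLOPS
    math_time_us = flops / tflops * 1e-6   # ~137 us before scaling by math_efficiency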
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/node.py ADDED
@@ -0,0 +1,376 @@
1
+ from abc import ABC
2
+
3
+ from ..config import CostModel
4
+ from ..device_mesh import LogicalDeviceMesh
5
+ from .comm_spec import CommSpec
6
+ from .sharding_spec import ShardingSpec
7
+ from .sharding_strategy import ShardingStrategy, StrategiesVector
8
+
9
+
10
+ class Node(ABC):
11
+
12
+ def __init__(self, layer):
13
+ self._layer = layer
14
+ self.is_shape_io = self._layer.is_shape_io
15
+ self._inputs = []
16
+ self._outputs = []
17
+ self.predecessor_nodes = []
18
+ self.predecessor_nodes_out_index = {}
19
+ self.successor_nodes = []
20
+ self.op_data = {}
21
+ self.global_to_local_op_name = {}
22
+ self.num_inputs = 0
23
+ self.is_replicated = layer.attrs.get("is_replicated", False)
24
+ self.same_spec_id = layer.attrs.get("same_spec_id", -1)
25
+ self.is_fake = self.same_spec_id != -1
26
+ self.building_block_id = layer.attrs.get("building_block_id", -1)
27
+ self.cost_level = -1
28
+ self.stage_type = layer.attrs.get("stage_type", None)
29
+ self.in_start_block = layer.attrs.get("in_start_block", False)
30
+ self.in_end_block = layer.attrs.get("in_end_block", False)
31
+ self.in_slowest_block = layer.attrs.get("in_slowest_block", False)
32
+ for i, input in enumerate(layer.inputs):
33
+ if input is None:
34
+ self._inputs.append(None)
35
+ self.op_data[f'input{i}'] = None
36
+ continue
37
+ input = input.copy()
38
+ input.attrs["broadcast_dims"] = []
39
+ self._inputs.append(input)
40
+ self.op_data[f'input{i}'] = input
41
+ self.global_to_local_op_name[input.name] = f'input{i}'
42
+
43
+ for i, output in enumerate(layer.outputs):
44
+ output = output.copy()
45
+ output.attrs["broadcast_dims"] = []
46
+ self._outputs.append(output)
47
+ self.op_data[f'output{i}'] = output
48
+ self.global_to_local_op_name[output.name] = f'output{i}'
49
+
50
+ self.sharding_weight = 1.0
51
+ self.resharding_weight = 1.0
52
+ self.pipeline_weight = 0
53
+ self.node_name = layer.name
54
+ self.node_type = 'normal_node'
55
+ self.num_inputs = layer.num_inputs
56
+ self.num_outputs = layer.num_outputs
57
+ self.dtype = layer.as_trt().precision
58
+ self.strategies_vector = []
59
+ self.node_runtime_profiler = None
60
+
61
+ def post_init(self, graph):
62
+ for input in self.inputs:
63
+ if input is None:
64
+ self.predecessor_nodes.append(None)
65
+ continue
66
+ if input.producer is None:
67
+ predecessor_node = graph.get_node(input.name)
68
+ self.predecessor_nodes.append(predecessor_node)
69
+ self.predecessor_nodes_out_index[predecessor_node] = 0
70
+ predecessor_node.successor_nodes.append(self)
71
+ else:
72
+ predecessor_node = graph.get_node(input.producer.name)
73
+ self.predecessor_nodes.append(predecessor_node)
74
+ self.predecessor_nodes_out_index[
75
+ predecessor_node] = input.output_index
76
+ predecessor_node.successor_nodes.append(self)
77
+
78
+ @property
79
+ def layer(self):
80
+ return self._layer
81
+
82
+ def get_input(self, index):
83
+ return self._inputs[index]
84
+
85
+ @property
86
+ def inputs(self):
87
+ return self._inputs
88
+
89
+ def get_output(self, index):
90
+ return self._outputs[index]
91
+
92
+ @property
93
+ def outputs(self):
94
+ return self._outputs
95
+
96
+ def collect_strategies(self, device_mesh):
97
+ strategies_vector = self._collect_strategies(device_mesh)
98
+ strategies_vector = self._post_process(strategies_vector)
99
+ self._update_sharding_cost(strategies_vector, device_mesh)
100
+ self.strategies_vector = strategies_vector
101
+ return self.strategies_vector
102
+
103
+ def _set_strategy(self, strategy, device_mesh):
104
+ strategies_vector = StrategiesVector(self)
105
+ if strategy is None:
106
+ dim_partition_dict_mapping = {}
107
+ for i in range(self.num_inputs):
108
+ dim_partition_dict_mapping[f'input{i}'] = {}
109
+ for i in range(self.num_outputs):
110
+ dim_partition_dict_mapping[f'output{i}'] = {}
111
+
112
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
113
+ dim_partition_dict_mapping, device_mesh)
114
+ assert 0 != len(
115
+ sharding_spec_mapping
116
+ ), f'failed to set default(all Replicate) strategy for node {self.node_name}'
117
+ name = 'RRs'
118
+ sharding_strategy = self._get_sharding_strategy(
119
+ name=name,
120
+ sharding_spec_mapping=sharding_spec_mapping,
121
+ communication_action_mapping={})
122
+ strategies_vector.append(sharding_strategy)
123
+
124
+ else:
125
+ sharding_specs_map = strategy.sharding_specs
126
+ comm_specs_map = strategy.communication_actions
127
+ dim_partition_dict_mapping = {}
128
+ for op_name, sharding_spec in sharding_specs_map.items():
129
+ dim_partition_dict_mapping[
130
+ op_name] = sharding_spec.dim_partition_dict
131
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
132
+ dim_partition_dict_mapping, device_mesh)
133
+ assert 0 != len(
134
+ sharding_spec_mapping
135
+ ), f'failed to set strategy for node {self.node_name}'
136
+ comm_specs_mapping = {}
137
+ if len(comm_specs_map) > 0:
138
+ for op_name, comm_spec in comm_specs_map.items():
139
+ comm_specs_mapping[op_name] = CommSpec(
140
+ comm_pattern=comm_spec.comm_pattern,
141
+ sharding_spec=sharding_spec_mapping[op_name],
142
+ logical_process_axis=comm_spec.logical_process_axis,
143
+ )
144
+ strategies_vector.append(
145
+ self._get_sharding_strategy(
146
+ name=strategy.name,
147
+ sharding_spec_mapping=sharding_spec_mapping,
148
+ communication_action_mapping=comm_specs_mapping))
149
+ return strategies_vector
150
+
151
+ def set_strategy(self, strategy, device_mesh):
152
+ strategies_vector = self._set_strategy(strategy, device_mesh)
153
+ strategies_vector = self._post_process(strategies_vector)
154
+ self._update_sharding_cost(strategies_vector, device_mesh)
155
+ self.strategies_vector = strategies_vector
156
+ return self.strategies_vector
157
+
158
+ def update_resharding_cost(self):
159
+ self._update_resharding_cost(self.strategies_vector)
160
+ return self.strategies_vector
161
+
162
+ def _to_sharding_spec_mapping(self, dim_partition_dict_mapping,
163
+ device_mesh):
164
+ results = {}
165
+ for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items(
166
+ ):
167
+ if op_data_name in self.op_data:
168
+ op_data = self.op_data[op_data_name]
169
+
170
+ def _to_sharding_spec(op_data, dim_partition_dict):
171
+ sharding_spec = ShardingSpec(
172
+ device_mesh,
173
+ op_data.dtype_str_size, [*op_data.shape],
174
+ [*op_data.max_shape], [*op_data.raw_shape],
175
+ dim_partition_dict=dim_partition_dict)
176
+ if sharding_spec.sanity_check():
177
+ return sharding_spec
178
+ else:
179
+ return None
180
+
181
+ sharding_spec = _to_sharding_spec(op_data, dim_partition_dict)
182
+ if sharding_spec:
183
+ results[op_data_name] = sharding_spec
184
+ else:
185
+ return {}
186
+ return results
187
+
188
+ def _get_sharding_strategy(self, name, sharding_spec_mapping,
189
+ communication_action_mapping):
190
+ return ShardingStrategy(
191
+ name=name,
192
+ sharding_specs=sharding_spec_mapping,
193
+ communication_actions=communication_action_mapping,
194
+ )
195
+
196
+ def _remove_duplicated_strategy(self, strategies_vector):
197
+ name_checklist = []
198
+ remove_list = []
199
+ for strategy in strategies_vector:
200
+ if strategy.name not in name_checklist:
201
+ name_checklist.append(strategy.name)
202
+ else:
203
+ remove_list.append(strategy)
204
+ for strategy in remove_list:
205
+ strategies_vector.remove(strategy)
206
+
207
+ def _post_process(self, strategies_vector):
208
+ # TODO: [KDuan] handle the transpose and size-1 dimension problems from ColossalAI, which have been processed before this point
209
+ for i in range(len(strategies_vector) - 1, -1, -1):
210
+ if strategies_vector[i] is None:
211
+ strategies_vector.pop(i)
212
+
213
+ self._remove_duplicated_strategy(strategies_vector)
214
+ return strategies_vector
215
+
216
+ def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh):
217
+ elapsed_time = self.node_runtime_profiler.runtime_profile(
218
+ self.layer, {}, {}, strategy, device_mesh)
219
+ return elapsed_time
220
+
221
+ def _model_sharding_cost_from_s_curve(self, strategy,
222
+ device_mesh: LogicalDeviceMesh):
223
+ '''
224
+ [ToDo][KDuan] preprofile the s_curve
225
+ '''
226
+ sharding_cost = 0.0
227
+ return sharding_cost
228
+
229
+ # this method may be overridden by some ops
230
+ def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh):
231
+ return 0.0
232
+
233
+ # this method may be overridden by some ops
234
+ def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh):
235
+ memory_time = (strategy.inout_memory_footprint /
236
+ device_mesh.cluster_info.memory_bw * 1e-3 *
237
+ device_mesh.cluster_info.memory_efficiency)
238
+ return memory_time
239
+
240
+ def _model_sharding_cost_from_alpha_beta(self, strategy,
241
+ device_mesh: LogicalDeviceMesh):
242
+ math_time = self._get_math_time(strategy, device_mesh)
243
+ mem_time = self._get_memory_time(strategy, device_mesh)
244
+ return max(math_time, mem_time)
245
+
246
+ def _get_communication_cost(self, strategy):
247
+ total_comm_cost = 0.0
248
+ for op_data_name, comm_spec in strategy.communication_actions.items():
249
+ comm_cost = comm_spec.get_comm_cost()
250
+ total_comm_cost = total_comm_cost + comm_cost
251
+ return total_comm_cost
252
+
253
+ def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh):
254
+ self._update_memory_cost(strategies)
255
+
256
+ if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA:
257
+ for strategy in strategies:
258
+ strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta(
259
+ strategy, device_mesh)
260
+ elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE:
261
+ for strategy in strategies:
262
+ strategy.sharding_cost = self._model_sharding_cost_from_s_curve(
263
+ strategy, device_mesh)
264
+ elif device_mesh.config.sharding_cost_model == CostModel.PROFILE:
265
+ for strategy in strategies:
266
+ strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta(
267
+ strategy, device_mesh)
268
+ if self.is_shape_io:
269
+ strategy.sharding_cost = strategy.alpha_beta_cost
270
+ else:
271
+ strategy.sharding_cost = self._profile_sharding_cost(
272
+ strategy, device_mesh)
273
+ elif device_mesh.config.sharding_cost_model == CostModel.ZERO:
274
+ for strategy in strategies:
275
+ strategy.sharding_cost = 0.0
276
+ else:
277
+ assert False, 'unsupported sharding cost model option: {}'.format(
278
+ device_mesh.config.sharding_cost_model)
279
+
280
+ for strategy in strategies:
281
+ strategy.communication_cost = self._get_communication_cost(strategy)
282
+
283
+ def _compute_resharding_cost(self, pre_sharding_sepc, cur_sharding_spec,
284
+ op_data):
285
+ transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
286
+ pre_sharding_sepc, cur_sharding_spec)
287
+ return (transform_path, comm_action_sequence, resharding_cost)
288
+
289
+ def _update_resharding_cost(self, strategies):
290
+ for strategy in strategies:
291
+ resharding_costs = {}
292
+ for pre_node, out_index in self.predecessor_nodes_out_index.items():
293
+ if pre_node is None:
294
+ continue
295
+ pre_node_out_data_name = pre_node.get_output(out_index).name
296
+ pre_node_out_data_lname = pre_node.global_to_local_op_name[
297
+ pre_node_out_data_name]
298
+ if pre_node_out_data_name not in self.global_to_local_op_name:
299
+ print(f"pre_node_out_data_name = {pre_node_out_data_name}")
300
+ continue
301
+ cur_node_inp_data_lname = self.global_to_local_op_name[
302
+ pre_node_out_data_name]
303
+ cur_sharding_spec = strategy.sharding_specs[
304
+ cur_node_inp_data_lname]
305
+
306
+ pre_node_out_sharding_specs = []
307
+ for pre_strategy in pre_node.strategies_vector:
308
+ pre_node_out_sharding_specs.append(
309
+ pre_strategy.sharding_specs[pre_node_out_data_lname])
310
+
311
+ if pre_node not in resharding_costs:
312
+ resharding_costs[pre_node.node_name] = []
313
+ for prev_sharding_spec in pre_node_out_sharding_specs:
314
+ resharding_cost = self._compute_resharding_cost(
315
+ prev_sharding_spec, cur_sharding_spec,
316
+ self.op_data[cur_node_inp_data_lname])
317
+ resharding_costs[pre_node.node_name].append(resharding_cost)
318
+ strategy.resharding_costs = resharding_costs
319
+
320
+ def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size):
321
+ dim_partition_list = []
322
+ for i in range(dim_size):
323
+ dim_partition_list.append({i: mesh_dim})
324
+ return dim_partition_list
325
+
326
+ def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1,
327
+ dim_size):
328
+ dim_partition_list = []
329
+ for i in range(dim_size):
330
+ for j in range(dim_size):
331
+ if i != j:
332
+ dim_partition_list.append({i: mesh_dim0, j: mesh_dim1})
333
+ return dim_partition_list
334
+
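The two enumeration helpers above generate candidate dim-partition dicts: the 1D variant shards one tensor dimension on a given mesh axis, and the 2D variant shards two different tensor dimensions on two mesh axes. The following standalone sketch re-implements them purely for illustration and shows what they produce for a rank-2 tensor.

    def enumerate_1d(mesh_dim, dim_size):
        # mirrors _enumerate_all_possible_1d_sharding
        return [{i: mesh_dim} for i in range(dim_size)]

    def enumerate_2d(mesh_dim0, mesh_dim1, dim_size):
        # mirrors _enumerate_all_possible_2d_sharding
        return [{i: mesh_dim0, j: mesh_dim1}
                for i in range(dim_size)
                for j in range(dim_size) if i != j]

    # Rank-2 tensor: shard either dim on mesh axis 0, or spread both dims
    # over the two mesh axes in either order.
    assert enumerate_1d([0], 2) == [{0: [0]}, {1: [0]}]
    assert enumerate_2d([0], [1], 2) == [{0: [0], 1: [1]}, {0: [1], 1: [0]}]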
335
+ def _update_memory_cost(self, strategies):
336
+ for strategy in strategies:
337
+ inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0
338
+ for spec in strategy.sharding_specs.values():
339
+ inout_memory_footprint += spec.get_sharded_size_per_device()
340
+ max_inout_memory_footprint += spec.get_max_sharded_size_per_device(
341
+ )
342
+
343
+ # account for the buffers used while the communication happens
344
+ comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0
345
+ for comm_spec in strategy.communication_actions.values():
346
+ comm_buffer_footprint += comm_spec.get_mem_cost()
347
+ max_comm_buffer_footprint += comm_spec.get_max_mem_cost()
348
+
349
+ # by the time the output0 comm action runs, the input buffer should already be released;
350
+ # this buffer size is used to estimate the memory time rather than the memory usage
351
+ strategy.inout_memory_footprint = inout_memory_footprint
352
+
353
+ strategy.comm_buff_memory_footprint = comm_buffer_footprint
354
+ strategy.peak_memory_footprint = max(max_inout_memory_footprint,
355
+ max_comm_buffer_footprint)
356
+
357
+ # The const memory (weight) is recorded in constant layers and should be accumulated
358
+ strategy.const_memory_footprint = 0.0
359
+
360
+ def _generate_bcast_dims(self, batch_dims, out_data_shape):
361
+ for output in self.outputs:
362
+ if output.broadcast_across_batch:
363
+ for bs in batch_dims:
364
+ if output.shape[
365
+ bs] == 1 and output.shape[bs] != out_data_shape[bs]:
366
+ output.attrs["broadcast_dims"].append(bs)
367
+
368
+ def _recover_bcast_partition_dict(self, partition_dict, op_data):
369
+ ret = {}
370
+ for data_dim, mesh_dim in partition_dict.items():
371
+ if data_dim not in op_data.attrs[
372
+ "broadcast_dims"] and data_dim + len(
373
+ op_data.shape) not in op_data.attrs[
374
+ "broadcast_dims"] and op_data.shape[data_dim] != 1:
375
+ ret[data_dim] = mesh_dim
376
+ return ret
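The helper above drops sharding from dimensions that are broadcast (extent 1), so those dimensions stay replicated while the real dimensions keep their mesh mapping. A standalone sketch is shown below; SimpleNamespace is used as a stand-in for the real op_data object, and the mock shape and attribute values are illustrative assumptions.

    from types import SimpleNamespace

    def recover_bcast_partition_dict(partition_dict, op_data):
        # keep a sharded dim only if it is not a broadcast dim and its extent > 1
        ret = {}
        for data_dim, mesh_dim in partition_dict.items():
            if (data_dim not in op_data.attrs["broadcast_dims"]
                    and data_dim + len(op_data.shape) not in op_data.attrs["broadcast_dims"]
                    and op_data.shape[data_dim] != 1):
                ret[data_dim] = mesh_dim
        return ret

    # Mock (1, 128) tensor whose leading dim is broadcast across the batch.
    tensor = SimpleNamespace(shape=(1, 128), attrs={"broadcast_dims": [0]})
    print(recover_bcast_partition_dict({0: [0], 1: [1]}, tensor))  # {1: [1]}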
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py ADDED
@@ -0,0 +1,60 @@
1
+ from .node import Node
2
+ from .sharding_strategy import StrategiesVector
3
+
4
+
5
+ class Normalization(Node):
6
+
7
+ def __init__(self, layer):
8
+ super().__init__(layer)
9
+ layer.to_subclass()
10
+ self.axes = layer.as_trt().axes
11
+ self.weight_bias_dim_base = 0
12
+ layer.to_base_class()
13
+
14
+ def _collect_strategies(self, device_mesh):
15
+ dim_partition_list = []
16
+ dim_size = len(self.op_data['input0'].shape)
17
+ dim_partition_list.append({})
18
+ dim_partition_list.extend(
19
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
20
+ dim_partition_list.extend(
21
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
22
+ dim_partition_list.extend(
23
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
24
+ dim_partition_list.extend(
25
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
26
+ strategies_vector = StrategiesVector(self)
27
+ for dim_partition_dict in dim_partition_list:
28
+ shard_reduction_axes = False
29
+ for dim in range(len(self.get_input(0).shape)):
30
+ if (self.axes & (1 << dim)) and dim in dim_partition_dict:
31
+ shard_reduction_axes = True
32
+ break
33
+ if shard_reduction_axes:
34
+ continue
35
+ dim_partition_dict_mapping = {
36
+ "input0": dim_partition_dict,
37
+ "output0": dim_partition_dict,
38
+ }
39
+ if self.num_inputs >= 2:
40
+ dim_partition_dict_mapping['input1'] = {}
41
+ if self.num_inputs >= 3:
42
+ dim_partition_dict_mapping['input2'] = {}
43
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
44
+ dim_partition_dict_mapping, device_mesh)
45
+ if 0 == len(sharding_spec_mapping):
46
+ continue
47
+ name = '{} = {} <normalization op> scale {}, bias {}'.format(
48
+ sharding_spec_mapping['output0'].sharding_sequence,
49
+ sharding_spec_mapping['input0'].sharding_sequence,
50
+ sharding_spec_mapping['input1'].sharding_sequence
51
+ if self.num_inputs >= 2 else 'None',
52
+ sharding_spec_mapping['input2'].sharding_sequence
53
+ if self.num_inputs >= 3 else 'None',
54
+ )
55
+ sharding_strategy = self._get_sharding_strategy(
56
+ name=name,
57
+ sharding_spec_mapping=sharding_spec_mapping,
58
+ communication_action_mapping={})
59
+ strategies_vector.append(sharding_strategy)
60
+ return strategies_vector
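In the class above, `axes` is a bitmask over input dimensions, and any candidate partition that would shard a normalized (reduction) dimension is skipped. The following standalone sketch reproduces that check for a layer norm over the last dimension of a rank-3 tensor; the bitmask value here is an illustrative assumption.

    def shards_reduction_axis(axes_bitmask, dim_partition_dict, rank):
        # mirrors the rejection test in Normalization._collect_strategies
        return any((axes_bitmask & (1 << dim)) and dim in dim_partition_dict
                   for dim in range(rank))

    axes = 1 << 2  # normalize over the last dim of a rank-3 tensor
    assert shards_reduction_axis(axes, {2: [0]}, rank=3)       # rejected
    assert not shards_reduction_axis(axes, {0: [0]}, rank=3)   # allowed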
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py ADDED
@@ -0,0 +1,79 @@
1
+ from .node import Node
2
+ from .sharding_strategy import StrategiesVector
3
+
4
+
5
+ class OuputNode(Node):
6
+
7
+ def _update_memory_cost(self, strategies):
8
+ for strategy in strategies:
9
+ if not self.no_memory_footprint:
10
+ strategy.const_memory_footprint = strategy.sharding_specs[
11
+ 'input0'].get_max_sharded_size_per_device()
12
+
13
+ def __init__(self, tensor):
14
+ self._layer = None
15
+ self.is_shape_io = False
16
+ self._inputs = []
17
+ self._outputs = []
18
+ self.predecessor_nodes = []
19
+ self.predecessor_nodes_out_index = {}
20
+ self.successor_nodes = []
21
+ self.op_data = {}
22
+ self.global_to_local_op_name = {}
23
+ self.is_replicated = tensor.attrs.get("is_replicated", False)
24
+ self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
25
+ self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
26
+ False)
27
+ self.building_block_id = -1
28
+ self.cost_level = -1
29
+ self.stage_type = None
30
+ self.in_start_block = None
31
+ self.in_end_block = None
32
+ self.in_slowest_block = None
33
+ input = tensor.copy()
34
+ self._inputs.append(input)
35
+ self.op_data['input0'] = input
36
+ self.global_to_local_op_name[input.name] = 'input0'
37
+
38
+ self.sharding_weight = 1.0
39
+ self.resharding_weight = 1.0
40
+ self.pipeline_weight = 0
41
+ self.node_name = tensor.name
42
+ self.node_type = 'output_node'
43
+ self.num_inputs = 0
44
+ self.num_outputs = 1
45
+ self.dtype = tensor.dtype
46
+ self.strategies_vector = []
47
+ self.node_runtime_profiler = None
48
+
49
+ def _collect_strategies(self, device_mesh):
50
+ dim_partition_list = []
51
+ dim_size = len(self.op_data['input0'].shape)
52
+ dim_partition_list.append({})
53
+ dim_partition_list.extend(
54
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
55
+ dim_partition_list.extend(
56
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
57
+ dim_partition_list.extend(
58
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
59
+ dim_partition_list.extend(
60
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
61
+
62
+ strategies_vector = StrategiesVector(self)
63
+ for dim_partition_dict in dim_partition_list:
64
+ dim_partition_dict_mapping = {'input0': dim_partition_dict}
65
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
66
+ dim_partition_dict_mapping, device_mesh)
67
+ if 0 == len(sharding_spec_mapping):
68
+ continue
69
+ sharding_seq = sharding_spec_mapping['input0'].sharding_sequence
70
+ sharding_strategy = self._get_sharding_strategy(
71
+ name=f'output-op {sharding_seq}',
72
+ sharding_spec_mapping=sharding_spec_mapping,
73
+ communication_action_mapping={})
74
+ strategies_vector.append(sharding_strategy)
75
+
76
+ return strategies_vector
77
+
78
+ def _profile_sharding_cost(self, strategy, device_mesh):
79
+ return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py ADDED
@@ -0,0 +1,67 @@
1
+ import copy
2
+ from enum import Enum
3
+
4
+ from .comm_spec import CommSpec
5
+ from .identity_node import Identity
6
+ from .sharding_strategy import StrategiesVector
7
+
8
+
9
+ class P2PType(Enum):
10
+ CROSS_DEVICE = 0
11
+ CROSS_HOST = 1
12
+
13
+
14
+ class P2PNode(Identity):
15
+
16
+ def __init__(self, layer):
17
+ super().__init__(layer)
18
+ self.p2p_type = layer.attrs["p2p_type"]
19
+ self.is_fake = True
20
+
21
+ def _collect_strategies(self, device_mesh):
22
+ # a P2P node has a single input, coming from its only predecessor
23
+ predecessor = self.predecessor_nodes[0]
24
+ strategies_vector = StrategiesVector(self)
25
+ for idx, strategy in enumerate(predecessor.strategies_vector):
26
+ # current node's local name input0 -> global name xxx
27
+ global_input_name = self.op_data['input0'].name
28
+ # global name xxx -> pre node local output name
29
+ prenode_local_name = predecessor.global_to_local_op_name[
30
+ global_input_name]
31
+ dim_partition_dict = copy.deepcopy(
32
+ strategy.sharding_specs[prenode_local_name].dim_partition_dict)
33
+ in0_partition_dict = dim_partition_dict
34
+ out_partition_dict = copy.deepcopy(dim_partition_dict)
35
+ dim_partition_dict_mapping = {
36
+ "input0": in0_partition_dict,
37
+ "output0": out_partition_dict,
38
+ }
39
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
40
+ dim_partition_dict_mapping, device_mesh)
41
+ if 0 == len(sharding_spec_mapping):
42
+ continue
43
+
44
+ logical_process_axis = [
45
+ ['p2p_cross_device']
46
+ ] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']]
47
+ # get communication action mapping
48
+ communication_action_mapping = {}
49
+ output0_comm_action = CommSpec(
50
+ comm_pattern='peer_to_peer',
51
+ sharding_spec=sharding_spec_mapping['output0'],
52
+ logical_process_axis=logical_process_axis,
53
+ )
54
+ communication_action_mapping['output0'] = output0_comm_action
55
+
56
+ name = '{} = <P2P op> {}'.format(
57
+ sharding_spec_mapping['output0'].sharding_sequence,
58
+ sharding_spec_mapping['input0'].sharding_sequence)
59
+ sharding_strategy = self._get_sharding_strategy(
60
+ name=name,
61
+ sharding_spec_mapping=sharding_spec_mapping,
62
+ communication_action_mapping=communication_action_mapping)
63
+ strategies_vector.append(sharding_strategy)
64
+ return strategies_vector
65
+
66
+ def _profile_sharding_cost(self, strategy, device_mesh):
67
+ return 0.0
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py ADDED
@@ -0,0 +1,40 @@
1
+ from tensorrt_llm.network import PluginInfo, get_plugin_info
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class PluginNode(Node):
8
+
9
+ def __init__(self, layer):
10
+ super().__init__(layer)
11
+ layer.to_subclass()
12
+ self.plugin = layer.as_trt().plugin
13
+ self.plugin_type: str = self.plugin.plugin_type
14
+ self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(),
15
+ layer.name)
16
+ layer.to_base_class()
17
+
18
+ def _collect_strategies(self, device_mesh):
19
+ raise NotImplementedError(
20
+ f"Auto parallel does not support {self.plugin_type} plugin right now."
21
+ )
22
+
23
+ def _default_strategy(self, device_mesh):
24
+ strategies_vector = StrategiesVector(self)
25
+ dim_partition_dict_mapping = {}
26
+ for idx in range(self.num_inputs):
27
+ dim_partition_dict_mapping[f'input{idx}'] = {}
28
+ for idx in range(self.num_outputs):
29
+ dim_partition_dict_mapping[f'output{idx}'] = {}
30
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
31
+ dim_partition_dict_mapping, device_mesh)
32
+ if 0 == len(sharding_spec_mapping):
33
+ return strategies_vector
34
+ name = '{}_all_replicate'.format(self.plugin_type)
35
+ sharding_strategy = self._get_sharding_strategy(
36
+ name=name,
37
+ sharding_spec_mapping=sharding_spec_mapping,
38
+ communication_action_mapping={})
39
+ strategies_vector.append(sharding_strategy)
40
+ return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py ADDED
File without changes
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py ADDED
@@ -0,0 +1,27 @@
1
+ import tensorrt as trt
2
+
3
+ from tensorrt_llm._utils import trt_dtype_to_str
4
+
5
+ from ..matmul_node import MatrixMultiply
6
+ from ..plugin_node import PluginNode
7
+
8
+
9
+ class GemmPlugin(MatrixMultiply, PluginNode):
10
+
11
+ def __init__(self, layer):
12
+ PluginNode.__init__(self, layer)
13
+ batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
14
+ self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
15
+ pfc_as_list = self.plugin_info.pfc_as_list
16
+ self.op0_transpose = (pfc_as_list['transa'][0] == 1)
17
+ self.op1_transpose = (pfc_as_list['transb'][0] == 1)
18
+ self.num_out_dims = len(self.get_output(0).shape)
19
+ self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0]))
20
+
21
+ def _collect_strategies(self, device_mesh):
22
+ strategies_vector = MatrixMultiply._collect_strategies(
23
+ self, device_mesh)
24
+ return strategies_vector
25
+
26
+ def _get_math_time(self, strategy, device_mesh):
27
+ return MatrixMultiply._get_math_time(self, strategy, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py ADDED
@@ -0,0 +1,395 @@
1
+ from enum import Enum, auto
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from tensorrt_llm.functional import PositionEmbeddingType
7
+ from tensorrt_llm.quantization import QuantMode
8
+
9
+ from ..plugin_node import PluginNode
10
+ from ..sharding_strategy import StrategiesVector
11
+
12
+
13
+ # WARNING: Must be kept in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
14
+ class IdxEntry(Enum):
15
+ QKV_TENSOR = auto()
16
+ K_TENSOR = auto()
17
+ V_TENSOR = auto()
18
+ SEQUENCE_LENGTH = auto()
19
+ HOST_PAST_KEY_VALUE_LENGTHS = auto()
20
+ HOST_MAX_ATTENTION_WINDOW = auto()
21
+ HOST_SINK_TOKEN_LENGTH = auto()
22
+ CONTEXT_LENGTHS = auto()
23
+ CACHE_INDIR = auto()
24
+ REQUEST_TYPES = auto()
25
+ KV_CACHE_BLOCK_OFFSETS = auto()
26
+ HOST_KV_CACHE_BLOCK_OFFSETS = auto()
27
+ HOST_KV_CACHE_POOL_POINTERS = auto()
28
+ PAST_KEY_VALUE = auto()
29
+ KV_CACHE_QUANTIZATION_SCALE = auto()
30
+ KV_CACHE_DEQUANTIZATION_SCALE = auto()
31
+ ROTARY_INV_FREQ = auto()
32
+ ROTARY_COS_SIN = auto()
33
+ ALIBI_SLOPES = auto()
34
+ RELATIVE_ATTENTION_BIAS = auto()
35
+ CROSS_QKV = auto()
36
+ CROSS_QKV_LENGTH = auto()
37
+ ENCODER_INPUT_LENGTH = auto()
38
+ HOST_CONTEXT_LENGTH = auto()
39
+ QKV_BIAS_TENSOR = auto()
40
+ SPEC_DECODING_PACKED_MASK = auto()
41
+ SPEC_DECODING_POSITION_OFFSETS = auto()
42
+ SPEC_DECODING_GENERATION_LENGTHS = auto()
43
+ HOST_RUNTIME_PERF_KNOBS = auto()
44
+
45
+
46
+ class IdxEntryParser:
47
+
48
+ def __init__(self, plugin_info):
49
+ self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0]
50
+ self.unfuse_qkv_gemm = bool(
51
+ plugin_info.pfc_as_list['unfuse_qkv_gemm'][0])
52
+ self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0])
53
+ self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0])
54
+ self.do_cross_attention = bool(
55
+ plugin_info.pfc_as_list['do_cross_attention'][0])
56
+ self.remove_input_padding = bool(
57
+ plugin_info.pfc_as_list['remove_input_padding'][0])
58
+ self.qkv_bias_enabled = bool(
59
+ plugin_info.pfc_as_list['qkv_bias_enabled'][0])
60
+ self.kv_cache_quant_mode = QuantMode(
61
+ plugin_info.pfc_as_list['kv_cache_quant_mode'][0])
62
+ self.position_embedding_type = PositionEmbeddingType(
63
+ plugin_info.pfc_as_list['position_embedding_type'][0])
64
+ self.is_spec_decoding_enabled = bool(
65
+ plugin_info.pfc_as_list['is_spec_decoding_enabled'][0])
66
+ self.init_entry_to_index()
67
+
68
+ # WARNING: Must be kept in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
69
+ def is_entry_used(self, entry: IdxEntry) -> bool:
70
+ if entry == IdxEntry.QKV_TENSOR:
71
+ return True
72
+ elif entry == IdxEntry.K_TENSOR:
73
+ return self.unfuse_qkv_gemm
74
+ elif entry == IdxEntry.V_TENSOR:
75
+ return self.unfuse_qkv_gemm
76
+ elif entry == IdxEntry.SEQUENCE_LENGTH:
77
+ return self.use_cache
78
+ elif entry == IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS:
79
+ return self.use_cache
80
+ elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW:
81
+ return True
82
+ elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH:
83
+ return True
84
+ elif entry == IdxEntry.CONTEXT_LENGTHS:
85
+ return True
86
+ elif entry == IdxEntry.CACHE_INDIR:
87
+ return self.use_cache
88
+ elif entry == IdxEntry.REQUEST_TYPES:
89
+ return True
90
+ elif entry == IdxEntry.KV_CACHE_BLOCK_OFFSETS:
91
+ return self.use_cache and self.paged_kv_cache
92
+ elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_OFFSETS:
93
+ return self.use_cache and self.paged_kv_cache
94
+ elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS:
95
+ return self.use_cache and self.paged_kv_cache
96
+ elif entry == IdxEntry.PAST_KEY_VALUE:
97
+ return self.use_cache and not self.paged_kv_cache
98
+ elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE:
99
+ return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
100
+ )
101
+ elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE:
102
+ return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
103
+ )
104
+ elif entry == IdxEntry.ROTARY_INV_FREQ:
105
+ return self.position_embedding_type.is_rope()
106
+ elif entry == IdxEntry.ROTARY_COS_SIN:
107
+ return self.position_embedding_type.is_rope()
108
+ elif entry == IdxEntry.ALIBI_SLOPES:
109
+ return self.position_embedding_type.is_alibi()
110
+ elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS:
111
+ return self.position_embedding_type == PositionEmbeddingType.relative
112
+ elif entry == IdxEntry.CROSS_QKV:
113
+ return self.do_cross_attention
114
+ elif entry == IdxEntry.CROSS_QKV_LENGTH:
115
+ return self.do_cross_attention
116
+ elif entry == IdxEntry.ENCODER_INPUT_LENGTH:
117
+ return self.do_cross_attention
118
+ elif entry == IdxEntry.HOST_CONTEXT_LENGTH:
119
+ return self.remove_input_padding
120
+ elif entry == IdxEntry.QKV_BIAS_TENSOR:
121
+ return self.qkv_bias_enabled
122
+ elif entry == IdxEntry.SPEC_DECODING_PACKED_MASK:
123
+ return self.is_spec_decoding_enabled
124
+ elif entry == IdxEntry.SPEC_DECODING_POSITION_OFFSETS:
125
+ return self.is_spec_decoding_enabled
126
+ elif entry == IdxEntry.SPEC_DECODING_GENERATION_LENGTHS:
127
+ return self.is_spec_decoding_enabled
128
+ elif entry == IdxEntry.HOST_RUNTIME_PERF_KNOBS:
129
+ return True
130
+ else:
131
+ return False
132
+
133
+ def init_entry_to_index(self):
134
+ self.entry_to_index = {}
135
+ index = 0
136
+ for entry in IdxEntry:
137
+ if self.is_entry_used(entry):
138
+ self.entry_to_index[entry] = index
139
+ index += 1
140
+
141
+ def get_index(self, entry: IdxEntry) -> int:
142
+ if entry not in self.entry_to_index:
143
+ raise Exception(
144
+ f"Entry {entry} does not exist in the gpt attention plugin inputs"
145
+ )
146
+ return self.entry_to_index[entry]
147
+
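Because optional tensors are skipped, the plugin input indices are not fixed: each used entry is assigned the next free slot in enum order, so disabling one entry shifts every later index. A minimal standalone sketch of the same pattern is shown below, using a hypothetical three-entry enum rather than the real IdxEntry.

    from enum import Enum, auto

    class Entry(Enum):       # hypothetical stand-in for IdxEntry
        A = auto()
        B = auto()           # an optional input
        C = auto()

    def build_index(is_used):
        # mirrors IdxEntryParser.init_entry_to_index
        mapping, index = {}, 0
        for entry in Entry:
            if is_used(entry):
                mapping[entry] = index
                index += 1
        return mapping

    # With B disabled, C takes index 1 instead of 2.
    print(build_index(lambda e: e is not Entry.B))  # A -> 0, C -> 1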
148
+
149
+ def get_partition(device_dim, device_ids):
150
+ if device_dim == [0]:
151
+ partition = device_ids.shape[0]
152
+ elif device_dim == [1]:
153
+ partition = device_ids.shape[1]
154
+ else:
155
+ assert device_dim == [0, 1] or device_dim == [1, 0]
156
+ partition = device_ids.size
157
+ return partition
158
+
159
+
160
+ class GPTAttentionPlugin(PluginNode):
161
+
162
+ def __init__(self, layer):
163
+ super().__init__(layer)
164
+ self.parser = IdxEntryParser(self.plugin_info)
165
+ assert self.num_inputs == len(
166
+ self.parser.entry_to_index
167
+ ), f'the number of plugin inputs ({self.num_inputs}) is invalid'
168
+ assert self.num_outputs == (
169
+ 2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else 1
170
+ ), f'the number of plugin outputs ({self.num_outputs}) has been changed'
171
+
172
+ def _tp_strategy(self, device_mesh):
173
+ strategies_vector = StrategiesVector(self)
174
+ head_dim = 1 if self.parser.remove_input_padding else 2
175
+ # TODO: allow mesh_dim = [0] or [1]
176
+ # for mesh_dim in ([0], [1], [0, 1]):
177
+ for mesh_dim in ([0, 1], ):
178
+ if self.parser.num_kv_heads != 1:
179
+ # MHA or GQA
180
+ # TODO: allow to duplicate kv when #kv_head < #partition
181
+ q_pdict = {
182
+ head_dim: mesh_dim
183
+ } # split in heads/hidden dimension
184
+ k_pdict = {
185
+ head_dim: mesh_dim
186
+ } # split in heads/hidden dimension
187
+ v_pdict = {
188
+ head_dim: mesh_dim
189
+ } # split in heads/hidden dimension
190
+ pastkv_pdict = {2: mesh_dim} # split in heads dimension
191
+ present_kv_pdict = {2: mesh_dim} # split in heads dimension
192
+ else:
193
+ # MQA
194
+ q_pdict = {
195
+ head_dim: mesh_dim
196
+ } # split in heads/hidden dimension
197
+ k_pdict = {} # RR
198
+ v_pdict = {} # RR
199
+ pastkv_pdict = {} # RR
200
+ present_kv_pdict = {} # RR
201
+
202
+ out0_pdict = {head_dim: mesh_dim}
203
+
204
+ dim_partition_dict_mapping = {
205
+ f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict,
206
+ f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict,
207
+ f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict,
208
+ 'output0': out0_pdict,
209
+ }
210
+ if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE):
211
+ dim_partition_dict_mapping[
212
+ f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict
213
+ dim_partition_dict_mapping['output1'] = present_kv_pdict
214
+ for i in range(self.num_inputs):
215
+ if f'input{i}' not in dim_partition_dict_mapping:
216
+ dim_partition_dict_mapping[f'input{i}'] = {}
217
+ for i in range(self.num_outputs):
218
+ if f'output{i}' not in dim_partition_dict_mapping:
219
+ dim_partition_dict_mapping[f'output{i}'] = {}
220
+
221
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
222
+ dim_partition_dict_mapping, device_mesh)
223
+ if 0 == len(sharding_spec_mapping):
224
+ continue
225
+ name = 'gptAttentionPlugin_tp_strategy'
226
+ sharding_strategy = self._get_sharding_strategy(
227
+ name=name,
228
+ sharding_spec_mapping=sharding_spec_mapping,
229
+ communication_action_mapping={})
230
+ strategies_vector.append(sharding_strategy)
231
+ return strategies_vector
232
+
233
+ def _dp_strategy(self, device_mesh):
234
+ strategies_vector = StrategiesVector(self)
235
+ for mesh_dim in ([0], [1], [0, 1]):
236
+ dim_partition_dict_mapping = {}
237
+ for i in range(self.num_inputs):
238
+ dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim}
239
+ for i in range(self.num_outputs):
240
+ dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim}
241
+
242
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
243
+ dim_partition_dict_mapping, device_mesh)
244
+ if 0 == len(sharding_spec_mapping):
245
+ continue
246
+ name = 'gptAttentionPlugin_dp_strategy'
247
+ sharding_strategy = self._get_sharding_strategy(
248
+ name=name,
249
+ sharding_spec_mapping=sharding_spec_mapping,
250
+ communication_action_mapping={})
251
+ strategies_vector.append(sharding_strategy)
252
+ return strategies_vector
253
+
254
+ def _collect_strategies(self, device_mesh):
255
+ if device_mesh.size == 1:
256
+ default_strategies = self._default_strategy(device_mesh)
257
+ else:
258
+ # Avoid using the all-replicate strategy for mesh size > 1,
259
+ # since the C++ runtime does not support it for the gpt attention plugin
260
+ default_strategies = StrategiesVector(self)
261
+ for idx, strategy in enumerate(default_strategies):
262
+ strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}'
263
+ if self.parser.unfuse_qkv_gemm:
264
+ tp_strategies = self._tp_strategy(device_mesh)
265
+ default_strategies.extend(tp_strategies)
266
+ # if the batch dim is not split, the default strategies should be used;
267
+ # if the batch dim is split, the dp strategies should be used.
268
+ # This information lets us distinguish the two kinds of strategy.
269
+ if not self.parser.remove_input_padding:
270
+ dp_strategies = self._dp_strategy(device_mesh)
271
+ default_strategies.extend(dp_strategies)
272
+ return default_strategies
273
+
274
+ @staticmethod
275
+ def parameter_generator(sharding_specs, plugin_info):
276
+
277
+ def get_shape(entry):
278
+ return sharding_specs[
279
+ f'input{parser.get_index(entry)}'].get_sharded_shape_per_device(
280
+ )
281
+
282
+ parser = IdxEntryParser(plugin_info)
283
+ updated_input_values = {}
284
+ batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0]
285
+ if parser.use_cache:
286
+ beams_width = get_shape(IdxEntry.CACHE_INDIR)[1]
287
+ max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2]
288
+ elif not parser.remove_input_padding:
289
+ max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1]
290
+ else:
291
+ max_seq_length = 1
292
+ host_request_types = torch.full(
293
+ (batch_size, ),
294
+ 1,
295
+ dtype=torch.int32,
296
+ device='cpu',
297
+ )
298
+ updated_input_values[parser.get_index(
299
+ IdxEntry.REQUEST_TYPES)] = host_request_types
300
+ context_lengths = torch.full(
301
+ (batch_size, ),
302
+ max_seq_length - 1,
303
+ dtype=torch.int32,
304
+ device=torch.cuda.current_device(),
305
+ )
306
+ updated_input_values[parser.get_index(
307
+ IdxEntry.CONTEXT_LENGTHS)] = context_lengths
308
+ host_max_attention_window_sizes = torch.tensor(
309
+ [max_seq_length],
310
+ dtype=torch.int32,
311
+ device='cpu',
312
+ )
313
+ updated_input_values[parser.get_index(
314
+ IdxEntry.HOST_MAX_ATTENTION_WINDOW
315
+ )] = host_max_attention_window_sizes
316
+ host_sink_token_length = torch.tensor(
317
+ [0],
318
+ dtype=torch.int32,
319
+ device='cpu',
320
+ )
321
+ updated_input_values[parser.get_index(
322
+ IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length
323
+ if parser.use_cache:
324
+ sequence_length = torch.full((batch_size, ),
325
+ max_seq_length,
326
+ dtype=torch.int32,
327
+ device=torch.cuda.current_device())
328
+ updated_input_values[parser.get_index(
329
+ IdxEntry.SEQUENCE_LENGTH)] = sequence_length
330
+ host_past_key_value_length = torch.full((batch_size, ),
331
+ max_seq_length - 1,
332
+ dtype=torch.int32,
333
+ device='cpu')
334
+ updated_input_values[parser.get_index(
335
+ IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS
336
+ )] = host_past_key_value_length
337
+ cache_indirections = torch.full(
338
+ (batch_size, beams_width, max_seq_length),
339
+ 0,
340
+ dtype=torch.int32,
341
+ device=torch.cuda.current_device())
342
+ updated_input_values[parser.get_index(
343
+ IdxEntry.CACHE_INDIR)] = cache_indirections
344
+ if parser.remove_input_padding:
345
+ host_context_lengths = torch.full(get_shape(
346
+ IdxEntry.HOST_CONTEXT_LENGTH),
347
+ max_seq_length - 1,
348
+ dtype=torch.int32,
349
+ device='cpu')
350
+ updated_input_values[parser.get_index(
351
+ IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths
352
+ return updated_input_values
353
+
354
+ def _profile_sharding_cost(self, strategy, device_mesh):
355
+ sharding_spec = strategy.sharding_specs[
356
+ f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"]
357
+ shard_dims = sharding_spec.dim_partition_dict
358
+ device_ids = device_mesh.phy_ids
359
+ if 2 in shard_dims:
360
+ device_dim = shard_dims[2]
361
+ partition = get_partition(device_dim, device_ids)
362
+ else:
363
+ partition = 1
364
+ if self.parser.is_entry_used(IdxEntry.K_TENSOR):
365
+ kv_sharding_spec = strategy.sharding_specs[
366
+ f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"]
367
+ kv_shard_dims = kv_sharding_spec.dim_partition_dict
368
+ if 2 in kv_shard_dims:
369
+ kv_device_dim = kv_shard_dims[2]
370
+ kv_partition = get_partition(kv_device_dim, device_ids)
371
+ else:
372
+ kv_partition = 1
373
+ else:
374
+ kv_partition = 1
375
+ num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy()
376
+ num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
377
+ tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy()
378
+ tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy()
379
+ num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
380
+ num_heads = np.maximum(num_heads // partition, 1)
381
+ tp_size[0] = partition
382
+ tp_rank[0] = 0
383
+
384
+ updated_layer_attrs = {
385
+ 'tp_size': tp_size,
386
+ 'tp_rank': tp_rank,
387
+ 'num_heads': num_heads,
388
+ 'num_kv_heads': num_kv_heads
389
+ }
390
+ updated_input_values = self.parameter_generator(strategy.sharding_specs,
391
+ self.plugin_info)
392
+ elapsed_time = self.node_runtime_profiler.runtime_profile(
393
+ self.layer, updated_layer_attrs, updated_input_values, strategy,
394
+ device_mesh)
395
+ return elapsed_time
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py ADDED
@@ -0,0 +1,11 @@
1
+ from ..identity_node import Identity
2
+ from ..plugin_node import PluginNode
3
+
4
+
5
+ class IdentityPlugin(Identity, PluginNode):
6
+
7
+ def __init__(self, layer):
8
+ PluginNode.__init__(self, layer)
9
+
10
+ def _collect_strategies(self, device_mesh):
11
+ return Identity._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py ADDED
@@ -0,0 +1,19 @@
1
+ import tensorrt as trt
2
+
3
+ from ..gather_node import Gather
4
+ from ..plugin_node import PluginNode
5
+
6
+
7
+ class LookupPlugin(Gather, PluginNode):
8
+
9
+ def __init__(self, layer):
10
+ PluginNode.__init__(self, layer)
11
+ self.mode = trt.GatherMode.DEFAULT
12
+ self.axis = 0
13
+ self.num_elementwise_dims = 0
14
+ self.input_id = 1
15
+ self.indice_id = 0
16
+ self.support_vocab_tp = True
17
+
18
+ def _collect_strategies(self, device_mesh):
19
+ return Gather._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py ADDED
@@ -0,0 +1,28 @@
1
+ from ..normalization_node import Normalization
2
+ from ..plugin_node import PluginNode
3
+
4
+
5
+ class LayernormPlugin(Normalization, PluginNode):
6
+
7
+ def __init__(self, layer):
8
+ PluginNode.__init__(self, layer)
9
+ # this is only true for LLM models, because layer norm only affects the hidden dim
10
+ hidden_dim = len(self.op_data['input0'].shape) - 1
11
+ self.axes = 1 << hidden_dim
12
+ self.weight_bias_dim_base = hidden_dim
13
+
14
+ def _collect_strategies(self, device_mesh):
15
+ return Normalization._collect_strategies(self, device_mesh)
16
+
17
+
18
+ class RMSnormPlugin(Normalization, PluginNode):
19
+
20
+ def __init__(self, layer):
21
+ PluginNode.__init__(self, layer)
22
+ # this is only true for LLM models, because rms norm only affects the hidden dim
23
+ hidden_dim = len(self.op_data['input0'].shape) - 1
24
+ self.axes = 1 << hidden_dim
25
+ self.weight_bias_dim_base = hidden_dim
26
+
27
+ def _collect_strategies(self, device_mesh):
28
+ return Normalization._collect_strategies(self, device_mesh)
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py ADDED
@@ -0,0 +1,73 @@
1
+ from tensorrt_llm._utils import trt_axes_to_dim
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Reduce(Node):
8
+
9
+ def __init__(self, layer):
10
+ super().__init__(layer)
11
+ layer.to_subclass()
12
+ self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes)
13
+ self.sum_mapping_dict = {}
14
+ num_input_dims = len(self.get_input(0).shape)
15
+ if layer.as_trt().keep_dims:
16
+ for i in range(num_input_dims):
17
+ self.sum_mapping_dict[i] = i
18
+ else:
19
+ output_index = 0
20
+ for i in range(num_input_dims):
21
+ if i not in self.reduce_dims:
22
+ self.sum_mapping_dict[i] = output_index
23
+ output_index += 1
24
+ assert output_index == len(self.get_output(0).shape)
25
+ layer.to_base_class()
26
+
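In the constructor above, sum_mapping_dict maps each input dimension to its output dimension: with keep_dims it is the identity map, otherwise reduced dimensions are dropped and the remaining ones are renumbered compactly. The following standalone sketch re-implements that mapping purely for illustration.

    def build_dim_mapping(num_input_dims, reduce_dims, keep_dims):
        # mirrors the construction in Reduce.__init__
        if keep_dims:
            return {i: i for i in range(num_input_dims)}
        mapping, out_index = {}, 0
        for i in range(num_input_dims):
            if i not in reduce_dims:
                mapping[i] = out_index
                out_index += 1
        return mapping

    # Reducing dim 1 of a rank-3 tensor without keep_dims:
    # input dims 0 and 2 map to output dims 0 and 1.
    assert build_dim_mapping(3, [1], keep_dims=False) == {0: 0, 2: 1}
    assert build_dim_mapping(3, [1], keep_dims=True) == {0: 0, 1: 1, 2: 2}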
27
+ def _collect_strategies(self, device_mesh):
28
+ dim_partition_list = []
29
+ dim_size = len(self.op_data['input0'].shape)
30
+ dim_partition_list.append({})
31
+ dim_partition_list.extend(
32
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
33
+ dim_partition_list.extend(
34
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
35
+ dim_partition_list.extend(
36
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
37
+ dim_partition_list.extend(
38
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
39
+ strategies_vector = StrategiesVector(self)
40
+ for dim_partition_dict in dim_partition_list:
41
+ recover_dims = []
42
+ out_partition_dict = {}
43
+ for dim in dim_partition_dict.keys():
44
+ if dim in self.reduce_dims:
45
+ recover_dims.append(dim)
46
+ elif dim in self.sum_mapping_dict:
47
+ out_partition_dict[
48
+ self.sum_mapping_dict[dim]] = dim_partition_dict[dim]
49
+ else:
50
+ assert 0, f'dim {dim} is not in reduce_dims or sum_mapping_dict'
51
+
52
+ for dim in recover_dims:
53
+ dim_partition_dict.pop(dim)
54
+
55
+ in0_parition_dict = dim_partition_dict
56
+ dim_partition_dict_mapping = {
57
+ "input0": in0_parition_dict,
58
+ "output0": out_partition_dict,
59
+ }
60
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
61
+ dim_partition_dict_mapping, device_mesh)
62
+ if 0 == len(sharding_spec_mapping):
63
+ continue
64
+ name = '{} = <reduce along dim {}> {}'.format(
65
+ sharding_spec_mapping['output0'].sharding_sequence,
66
+ self.reduce_dims,
67
+ sharding_spec_mapping['input0'].sharding_sequence)
68
+ sharding_strategy = self._get_sharding_strategy(
69
+ name=name,
70
+ sharding_spec_mapping=sharding_spec_mapping,
71
+ communication_action_mapping={})
72
+ strategies_vector.append(sharding_strategy)
73
+ return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py ADDED
@@ -0,0 +1,56 @@
1
+ from .node import Node
2
+ from .sharding_strategy import StrategiesVector
3
+
4
+
5
+ class Select(Node):
6
+
7
+ def __init__(self, layer):
8
+ super().__init__(layer)
9
+ batch_dims = [i for i in range(len(self.get_output(0).shape))]
10
+ self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
11
+
12
+ def _collect_strategies(self, device_mesh):
13
+ dim_partition_list = []
14
+ dim_size = len(self.op_data['output0'].shape)
15
+ dim_partition_list.append({})
16
+ dim_partition_list.extend(
17
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
18
+ dim_partition_list.extend(
19
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
20
+ dim_partition_list.extend(
21
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
22
+ dim_partition_list.extend(
23
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
24
+
25
+ strategies_vector = StrategiesVector(self)
26
+ for dim_partition_dict in dim_partition_list:
27
+ # the three inputs are condition, true tensor and false tensor
28
+ in0_partition_dict = self._recover_bcast_partition_dict(
29
+ dim_partition_dict, self.op_data['input0'])
30
+ in1_partition_dict = self._recover_bcast_partition_dict(
31
+ dim_partition_dict, self.op_data['input1'])
32
+ in2_partition_dict = self._recover_bcast_partition_dict(
33
+ dim_partition_dict, self.op_data['input2'])
34
+ out_partition_dict = dim_partition_dict
35
+ dim_partition_dict_mapping = {
36
+ "input0": in0_partition_dict,
37
+ "input1": in1_partition_dict,
38
+ "input2": in2_partition_dict,
39
+ "output0": out_partition_dict,
40
+ }
41
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
42
+ dim_partition_dict_mapping, device_mesh)
43
+ if 0 == len(sharding_spec_mapping):
44
+ continue
45
+ name = '{} = <select op {}> {} {}'.format(
46
+ sharding_spec_mapping['output0'].sharding_sequence,
47
+ sharding_spec_mapping['input0'].sharding_sequence,
48
+ sharding_spec_mapping['input1'].sharding_sequence,
49
+ sharding_spec_mapping['input2'].sharding_sequence)
50
+
51
+ sharding_strategy = self._get_sharding_strategy(
52
+ name=name,
53
+ sharding_spec_mapping=sharding_spec_mapping,
54
+ communication_action_mapping={})
55
+ strategies_vector.append(sharding_strategy)
56
+ return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_consistency.py ADDED
@@ -0,0 +1,832 @@
1
+ import math
2
+ import operator
3
+ from functools import reduce
4
+ from typing import List, Tuple
5
+
6
+ import pandas as pd
7
+
8
+ from .comm_spec import CommSpec
9
+ from .sharding_spec import ShardingSpec
10
+
11
+
12
+ class ShapeConsistencyManager(object):
13
+
14
+ def __init__(self):
15
+ self.forward_only = True
16
+ self.cached_spec_pairs_transform_path = {}
17
+ self.cache_hit = 0
18
+ self.cache_miss = 0
19
+
20
+ def all_gather_simulator(self, target_pair):
21
+ _, shard_list = target_pair
22
+ new_shard_list = []
23
+ return new_shard_list
24
+
25
+ def all_to_all_simulator(self, f_target_pair, b_target_pair):
26
+ '''
27
+ Simulate an all-to-all operation: analyze the communication cost
+ and the influence on the DimSpec.
+ 
+ We ban all representations whose shard_list is in decreasing order,
+ such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
+ If the back shard_list is not empty, we extend the front shard_list with it
+ and the back dimension becomes replicated, e.g.:
+ all-to-all(S0, S1) -> [S01, R]
+ all-to-all(R, S1) -> [S1, R]
+ Otherwise, the front shard_list is moved to the back, e.g.:
+ all-to-all(S0, R) -> [R, S0]
+ 
+ Argument:
+ f_target_pair(Tuple[int, List[int]]): the first element is the front dimension of the tensor to be sharded,
+ and the second element describes which logical axes are sharded in that dimension.
+ b_target_pair(Tuple[int, List[int]]): the same, for the back dimension.
46
+ '''
47
+ _, f_shard_list = f_target_pair
48
+ _, b_shard_list = b_target_pair
49
+ if not len(b_shard_list):
50
+ b_shard_list.extend(f_shard_list)
51
+ f_shard_list = []
52
+ else:
53
+ f_shard_list.extend(b_shard_list)
54
+ b_shard_list = []
55
+
56
+ return f_shard_list, b_shard_list
57
+
58
+ def shard_simulator(self, target_pair, legal_sharding_dims):
59
+ '''
60
+ Simulate a shard operation: analyze the communication cost (always ZERO)
+ and the influence on the DimSpec.
+ 
+ We do not allow non-contiguous layouts, so shard(S0) -> S02 is NOT allowed.
+ In addition, we ban all representations whose shard_list is in decreasing order,
+ such as S10, so shard(S0) -> S10 is NOT allowed.
+ Therefore, for an R dimension we can append any legal sharding dim to it, e.g.:
+ shard(R) -> S0
+ For an S dimension, the shard_list after sharding must remain in ascending order, e.g.:
+ shard(S0) -> S01
72
+
73
+ Argument:
74
+ target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
75
+ and the second element describes which logical axis will be sharded in that dimension.
76
+ '''
77
+ _, shard_list = target_pair
78
+ shard_list_list, logical_process_axis = [], []
79
+ for dim in legal_sharding_dims:
80
+ if len(shard_list) != 0 and dim <= shard_list[-1]:
81
+ continue
82
+ new_shard_list = shard_list + [dim]
83
+ shard_list_list.append(new_shard_list)
84
+ logical_process_axis.append([dim])
85
+
86
+ # we support sorted 2D mesh here
87
+ if len(legal_sharding_dims) == 2 and len(shard_list) == 0:
88
+ shard_list_list.append(legal_sharding_dims)
89
+ logical_process_axis.append(legal_sharding_dims)
90
+ return shard_list_list, logical_process_axis
91
+
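For a concrete feel of the rules above, the following standalone sketch copies the shard transition logic and prints its outputs on a two-axis mesh: a replicated dimension can become S0, S1, or S01, while a dimension already sharded as S0 can only grow to S01. It is a self-contained illustration, not an import of the class above.

    def shard_simulator(target_pair, legal_sharding_dims):
        # standalone copy of ShapeConsistencyManager.shard_simulator, for illustration
        _, shard_list = target_pair
        shard_list_list, logical_process_axis = [], []
        for dim in legal_sharding_dims:
            if len(shard_list) != 0 and dim <= shard_list[-1]:
                continue
            shard_list_list.append(shard_list + [dim])
            logical_process_axis.append([dim])
        if len(legal_sharding_dims) == 2 and len(shard_list) == 0:
            shard_list_list.append(legal_sharding_dims)
            logical_process_axis.append(legal_sharding_dims)
        return shard_list_list, logical_process_axis

    print(shard_simulator((0, []), [0, 1]))   # ([[0], [1], [0, 1]], [[0], [1], [0, 1]])
    print(shard_simulator((0, [0]), [0, 1]))  # ([[0, 1]], [[1]])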
92
+ def mix_gather_simulator(self, f_target_pair, b_target_pair):
93
+ '''
94
+ Assume index of f and b target pairs are 'f' and 'b'
95
+ S0S1 => Input: (f, [0]), (b, [1]) Output: [f, b], [[0], [1]]
96
+ S1S0 => Input: (f, [1]), (b, [0]) Output: [f, b], [[1], [0]]
97
+ S01R => Input: (f, [0, 1]), (b, []) Output: [f], [[0, 1]]
98
+ RS01 => Input: (f, []), (b, [0, 1]) Output: [b], [[0, 1]]
99
+ '''
100
+ if f_target_pair[1] and b_target_pair[1]:
101
+ return [f_target_pair[0],
102
+ b_target_pair[0]], [f_target_pair[1], b_target_pair[1]]
103
+ if f_target_pair[1]:
104
+ return [f_target_pair[0]], [f_target_pair[1]]
105
+ if b_target_pair[1]:
106
+ return [b_target_pair[0]], [b_target_pair[1]]
107
+
108
+ def get_all_all_gather_spec(self, source_spec, orig_cost):
109
+ '''
110
+ Get all valid sharding specs from source_spec with single all-gather operation, and
111
+ accumulate communication cost on origin cost which will finally be used in auto sharding solver.
112
+ For the all-gather operation, we just care about the S dimension.
113
+
114
+ Argument:
115
+ source_spec(ShardingSpec): the ShardingSpec of the source_spec.
116
+ orig_cost(Dict[str, float]): the original communication cost before this operation.
117
+
118
+ Return:
119
+ valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
120
+
121
+ Example:
122
+ dim_partition_dict = {0: [0], 1: [1]}
123
+ # DistSpec:
124
+ # shard_sequence: S0,S1,R
125
+ # device_mesh_shape: (4, 4)
126
+ sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
127
+ shape_consistency_manager = ShapeConsistencyManager()
128
+ rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
129
+ print(rst_dict)
130
+
131
+ Output:
132
+ {DistSpec:
133
+ shard_sequence: R,S1,R
134
+ device_mesh_shape: (4, 4): 0, DistSpec:
135
+ shard_sequence: S0,R,R
136
+ device_mesh_shape: (4, 4): 0}
137
+ '''
138
+ valid_spec_dict = {}
139
+ comm_pattern = 'all_gather'
140
+ for target_pair in source_spec.dim_partition_dict.items():
141
+ shard_list = self.all_gather_simulator(target_pair)
142
+ index = target_pair[0]
143
+ new_dim_partition_dict = source_spec.dim_partition_dict.copy()
144
+
145
+ # We won't add empty list into dim_partition_dict
146
+ # The key will be popped if the related shard_list is empty
147
+ if shard_list:
148
+ new_dim_partition_dict[index] = shard_list
149
+ else:
150
+ new_dim_partition_dict.pop(index)
151
+
152
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
153
+ gather_dim = index
154
+ logical_process_axis = target_pair[1]
155
+ comm_spec = CommSpec(comm_pattern,
156
+ sharding_spec=source_spec,
157
+ gather_dim=[gather_dim],
158
+ logical_process_axis=[logical_process_axis],
159
+ forward_only=self.forward_only)
160
+
161
+ # compute the communication cost with CommSpec
162
+
163
+ # generate new sharding spec
164
+ new_sharding_spec = ShardingSpec(
165
+ source_spec.device_mesh,
166
+ source_spec.data_type_size,
167
+ source_spec.entire_shape,
168
+ source_spec.max_entire_shape,
169
+ source_spec.raw_shape,
170
+ dim_partition_dict=new_dim_partition_dict)
171
+
172
+ if not new_sharding_spec.sanity_check():
173
+ continue
174
+ cost = comm_spec.get_comm_cost()
175
+ valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
176
+ return valid_spec_dict
177
+
178
+ def get_all_all_to_all_spec(self, source_spec, orig_cost):
179
+ '''
180
+ Get all valid sharding specs from source_spec with single all-to-all operation, and
181
+ accumulate communication cost on origin cost which will finally be used in auto sharding solver.
182
+ For the all-to-all operation, we just care about the pairs containing S dimension.
183
+
184
+ Argument:
185
+ source_spec(ShardingSpec): the ShardingSpec of the source_spec.
186
+ orig_cost(Dict[str, float]): the original communication cost before this operation.
187
+
188
+ Return:
189
+ valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
190
+
191
+ Example:
192
+ dim_partition_dict = {0: [0], 1: [1]}
193
+ # DistSpec:
194
+ # shard_sequence: S0,S1,R
195
+ # device_mesh_shape: (4, 4)
196
+ sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
197
+ shape_consistency_manager = ShapeConsistencyManager()
198
+ rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
199
+ print(rst_dict)
200
+
201
+ Output:
202
+ {DistSpec:
203
+ shard_sequence: S01,R,R
204
+ device_mesh_shape: (4, 4): 0, DistSpec:
205
+ shard_sequence: R,S1,S0
206
+ device_mesh_shape: (4, 4): 0, DistSpec:
207
+ shard_sequence: S0,R,S1
208
+ device_mesh_shape: (4, 4): 0}
209
+ '''
210
+ valid_spec_dict = {}
211
+ comm_pattern = 'all_to_all'
212
+ tensor_dims = len(source_spec.entire_shape)
213
+ for f_index in range(tensor_dims - 1):
214
+ for b_index in range(f_index + 1, tensor_dims):
215
+ # skip (R, R) cases
216
+ if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
217
+ continue
218
+ else:
219
+ if f_index in source_spec.dim_partition_dict:
220
+ '''
221
+ # skip (S01, R) -> (R, S01) is NOT allowed
222
+ if len(source_spec.dim_partition_dict[f_index]) >= 2:
223
+ continue
224
+ '''
225
+ f_target_pair = (f_index, [
226
+ *source_spec.dim_partition_dict[f_index]
227
+ ])
228
+ else:
229
+ f_target_pair = (f_index, [])
230
+ if b_index in source_spec.dim_partition_dict:
231
+ '''
232
+ # skip (R, S01) -> (S01, R) is NOT allowed
233
+ if len(source_spec.dim_partition_dict[b_index]) >= 2:
234
+ continue
235
+ '''
236
+ b_target_pair = (b_index, [
237
+ *source_spec.dim_partition_dict[b_index]
238
+ ])
239
+ else:
240
+ b_target_pair = (b_index, [])
241
+
242
+ # skip (S1, S0) -> S10
243
+ if f_target_pair[1] and b_target_pair[
244
+ 1] and f_target_pair[1][0] >= b_target_pair[1][0]:
245
+ continue
246
+ f_shard_list, b_shard_list = self.all_to_all_simulator(
247
+ f_target_pair, b_target_pair)
248
+ f_index = f_target_pair[0]
249
+ b_index = b_target_pair[0]
250
+
251
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
252
+ if len(f_shard_list) < len(f_target_pair[1]):
253
+ gather_dim = f_index
254
+ shard_dim = b_index
255
+ logical_process_axis = f_target_pair[1]
256
+ else:
257
+ gather_dim = b_index
258
+ shard_dim = f_index
259
+ logical_process_axis = b_target_pair[1]
260
+ comm_spec = CommSpec(
261
+ comm_pattern,
262
+ sharding_spec=source_spec,
263
+ gather_dim=[gather_dim],
264
+ shard_dim=[shard_dim],
265
+ logical_process_axis=[logical_process_axis],
266
+ forward_only=self.forward_only)
267
+
268
+ # compute the communication cost with CommSpec
269
+
270
+ new_dim_partition_dict = source_spec.dim_partition_dict.copy()
271
+
272
+ # We won't add empty list into dim_partition_dict
273
+ # The key will be popped if the related shard_list is empty
274
+ if f_shard_list:
275
+ new_dim_partition_dict[f_index] = f_shard_list
276
+ else:
277
+ new_dim_partition_dict.pop(f_index)
278
+ if b_shard_list:
279
+ new_dim_partition_dict[b_index] = b_shard_list
280
+ else:
281
+ new_dim_partition_dict.pop(b_index)
282
+
283
+ # generate new sharding spec
284
+
285
+ new_sharding_spec = ShardingSpec(
286
+ source_spec.device_mesh,
287
+ source_spec.data_type_size,
288
+ source_spec.entire_shape,
289
+ source_spec.max_entire_shape,
290
+ source_spec.raw_shape,
291
+ dim_partition_dict=new_dim_partition_dict)
292
+ if not new_sharding_spec.sanity_check():
293
+ continue
294
+ cost = comm_spec.get_comm_cost()
295
+ valid_spec_dict[new_sharding_spec] = (comm_spec,
296
+ cost + orig_cost)
297
+
298
+ return valid_spec_dict
299
+
300
+ def get_all_shard_spec(self, source_spec, orig_cost):
301
+ '''
302
+ Get all valid sharding specs from source_spec with single shard operation, and
303
+ accumulate communication cost on origin cost which will finally be used in auto sharding solver.
304
+ For the sharding operation, we just care about legal sharding dimensions.
305
+
306
+ Argument:
307
+ source_spec(ShardingSpec): the ShardingSpec of the source_spec.
308
+ orig_cost(float): the original communication cost before this operation.
309
+
310
+ Return:
311
+ valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single shard operation.
312
+
313
+ Example:
314
+ dim_partition_dict = {0: [0]}
315
+ # DistSpec:
316
+ # shard_sequence: S0,R,R
317
+ # device_mesh_shape: (4, 4)
318
+ sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
319
+ shape_consistency_manager = ShapeConsistencyManager()
320
+ rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
321
+ print(rst_dict)
322
+
323
+ Output:
324
+ {DistSpec:
325
+ shard_sequence: S01,R,R
326
+ device_mesh_shape: (4, 4): 0, DistSpec:
327
+ shard_sequence: S0,S1,R
328
+ device_mesh_shape: (4, 4): 0, DistSpec:
329
+ shard_sequence: S0,R,S1
330
+ device_mesh_shape: (4, 4): 0}
331
+ '''
332
+ valid_spec_dict = {}
333
+ comm_pattern = 'split'
334
+
335
+ # legal sharding dims are the mesh axes that are still available to use.
336
+ legal_sharding_dims = [
337
+ i for i in range(len(source_spec.device_mesh.mesh_shape))
338
+ ]
339
+ for dim, shard_list in source_spec.dim_partition_dict.items():
340
+ for element in shard_list:
341
+ legal_sharding_dims.remove(element)
342
+ if len(legal_sharding_dims) == 0:
343
+ return valid_spec_dict
344
+
345
+ tensor_dims = len(source_spec.entire_shape)
346
+
347
+ for index in range(tensor_dims):
348
+ if index not in source_spec.dim_partition_dict:
349
+ shard_list_list, logical_process_axes = self.shard_simulator(
350
+ (index, []), legal_sharding_dims)
351
+ else:
352
+ shard_list_list, logical_process_axes = self.shard_simulator(
353
+ (index, source_spec.dim_partition_dict[index]),
354
+ legal_sharding_dims)
355
+ if not shard_list_list:
356
+ continue
357
+ for shard_list, logical_process_axis in zip(shard_list_list,
358
+ logical_process_axes):
359
+ new_dim_partition_dict = source_spec.dim_partition_dict.copy()
360
+ new_dim_partition_dict[index] = shard_list
361
+
362
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
363
+ comm_spec = CommSpec(
364
+ comm_pattern,
365
+ sharding_spec=source_spec,
366
+ shard_dim=[index],
367
+ logical_process_axis=[logical_process_axis],
368
+ forward_only=self.forward_only)
369
+
370
+ # generate new sharding spec
371
+ new_sharding_spec = ShardingSpec(
372
+ source_spec.device_mesh,
373
+ source_spec.data_type_size,
374
+ source_spec.entire_shape,
375
+ source_spec.max_entire_shape,
376
+ source_spec.raw_shape,
377
+ dim_partition_dict=new_dim_partition_dict)
378
+ if not new_sharding_spec.sanity_check():
379
+ continue
380
+ # compute the communication cost with CommSpec
381
+ cost = comm_spec.get_comm_cost()
382
+ valid_spec_dict[new_sharding_spec] = (comm_spec,
383
+ cost + orig_cost)
384
+
385
+ return valid_spec_dict
386
+
387
+ def get_all_mixed_shard_spec(self, source_spec, orig_cost):
388
+ '''
389
+ Get all valid sharding specs from source_spec with a single mixed shard operation (sharding two tensor dims at once), and
390
+ accumulate communication cost on origin cost which will finally be used in auto sharding solver.
391
+ For the sharding operation, we just care about legal sharding dimensions.
392
+ '''
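+ # Illustrative example (assuming all sharded dims are divisible): starting from a fully
+ # replicated 3-D tensor on a 2-D mesh, this yields the six mixed shardings
+ # S0,S1,R / S1,S0,R / S0,R,S1 / S1,R,S0 / R,S0,S1 / R,S1,S0 (one free mesh axis per chosen tensor dim).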
393
+ valid_spec_dict = {}
394
+ comm_pattern = 'split'
395
+
396
+ # legal sharding dims are the mesh axes that are still available to use.
397
+ legal_sharding_dims = [
398
+ i for i in range(len(source_spec.device_mesh.mesh_shape))
399
+ ]
400
+ for dim, shard_list in source_spec.dim_partition_dict.items():
401
+ for element in shard_list:
402
+ legal_sharding_dims.remove(element)
403
+ if len(legal_sharding_dims) != 2:
404
+ return valid_spec_dict
405
+
406
+ tensor_dims = len(source_spec.entire_shape)
407
+ for f_index in range(tensor_dims):
408
+ for b_index in range(tensor_dims):
409
+ if f_index != b_index:
410
+ shard_dims = [f_index, b_index]
411
+ logical_process_axes = [[legal_sharding_dims[0]],
412
+ [legal_sharding_dims[1]]]
413
+ new_dim_partition_dict = source_spec.dim_partition_dict.copy(
414
+ )
415
+ new_dim_partition_dict[f_index] = [legal_sharding_dims[0]]
416
+ new_dim_partition_dict[b_index] = [legal_sharding_dims[1]]
417
+ comm_spec = CommSpec(
418
+ comm_pattern,
419
+ sharding_spec=source_spec,
420
+ shard_dim=shard_dims,
421
+ logical_process_axis=logical_process_axes,
422
+ forward_only=self.forward_only)
423
+
424
+ # generate new sharding spec
425
+ new_sharding_spec = ShardingSpec(
426
+ source_spec.device_mesh,
427
+ source_spec.data_type_size,
428
+ source_spec.entire_shape,
429
+ source_spec.max_entire_shape,
430
+ source_spec.raw_shape,
431
+ dim_partition_dict=new_dim_partition_dict)
432
+ if not new_sharding_spec.sanity_check():
433
+ continue
434
+ cost = comm_spec.get_comm_cost()
435
+ valid_spec_dict[new_sharding_spec] = (comm_spec,
436
+ cost + orig_cost)
437
+ return valid_spec_dict
438
+
439
+ def get_all_mix_gather_spec(self, source_spec, orig_cost):
440
+ '''
441
+ S0S1 -> RR
442
+ S1S0 -> RR
443
+ S01R -> RR
444
+ RS01 -> RR
445
+ '''
446
+ valid_spec_dict = {}
447
+ comm_pattern = 'all_gather'
448
+ tensor_dims = len(source_spec.entire_shape)
449
+ for f_index in range(tensor_dims - 1):
450
+ for b_index in range(f_index + 1, tensor_dims):
451
+ if (f_index not in source_spec.dim_partition_dict) and (
452
+ b_index not in source_spec.dim_partition_dict):
453
+ continue
454
+ else:
455
+ if f_index in source_spec.dim_partition_dict:
456
+ # skip (S10, R) -> (R, R)
457
+ '''
458
+ if len(
459
+ f_target_pair[1]
460
+ ) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
461
+ continue
462
+ '''
463
+ f_target_pair = (f_index, [
464
+ *source_spec.dim_partition_dict[f_index]
465
+ ])
466
+ else:
467
+ f_target_pair = (f_index, [])
468
+ if b_index in source_spec.dim_partition_dict:
469
+ # skip (R, S10) -> (R, R)
470
+ '''
471
+ if len(
472
+ b_target_pair[1]
473
+ ) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
474
+ continue
475
+ '''
476
+ b_target_pair = (b_index, [
477
+ *source_spec.dim_partition_dict[b_index]
478
+ ])
479
+ else:
480
+ b_target_pair = (b_index, [])
481
+ if len(f_target_pair[1]) + len(b_target_pair[1]) != 2:
482
+ continue
483
+ gather_dim, logical_process_axes = self.mix_gather_simulator(
484
+ f_target_pair, b_target_pair)
485
+ comm_spec = CommSpec(comm_pattern,
486
+ sharding_spec=source_spec,
487
+ gather_dim=gather_dim,
488
+ logical_process_axis=logical_process_axes,
489
+ forward_only=self.forward_only,
490
+ mix_gather=True)
491
+
492
+ new_dim_partition_dict = {}
493
+ # generate new sharding spec
494
+ new_sharding_spec = ShardingSpec(
495
+ source_spec.device_mesh,
496
+ source_spec.data_type_size,
497
+ source_spec.entire_shape,
498
+ source_spec.max_entire_shape,
499
+ source_spec.raw_shape,
500
+ dim_partition_dict=new_dim_partition_dict)
501
+ if not new_sharding_spec.sanity_check():
502
+ continue
503
+ cost = comm_spec.get_comm_cost()
504
+ valid_spec_dict[new_sharding_spec] = (comm_spec,
505
+ cost + orig_cost)
506
+
507
+ return valid_spec_dict
508
+
509
+ def get_all_one_step_transform_spec(self, source_spec, orig_cost):
510
+ '''
511
+ Get all valid sharding specs from source_spec with one step transform, and
512
+ accumulate communication cost on origin cost which will finally be used in auto sharding solver.
513
+ Note:
514
+ all-gather will eliminate a sharding dimension, all-to-all will keep the number of sharding dimensions the same as before,
515
+ and shard will add a sharding dimension. Therefore, the results of the above operations are mutually exclusive,
516
+ so we can safely put them together.
517
+
518
+ Argument:
519
+ source_spec(ShardingSpec): the ShardingSpec of the source_spec.
520
+ orig_cost(float): the original communication cost before this operation.
521
+
522
+ Return:
523
+ valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single one-step transform.
524
+ '''
525
+ valid_spec_dict = {}
526
+ valid_spec_dict.update(
527
+ self.get_all_all_gather_spec(source_spec, orig_cost))
528
+ valid_spec_dict.update(
529
+ self.get_all_all_to_all_spec(source_spec, orig_cost))
530
+ valid_spec_dict.update(
531
+ self.get_all_mix_gather_spec(source_spec, orig_cost))
532
+ valid_spec_dict.update(
533
+ self.get_all_mixed_shard_spec(source_spec, orig_cost))
534
+ valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
535
+ return valid_spec_dict
536
+
537
+ def mem_cost(self, comm_action_sequence: List[CommSpec], mem_pattern='opt'):
538
+ """memory cost of the communication action sequence
539
+
540
+ Args:
541
+ comm_action_sequence (List[CommSpec]): list of communication actions
542
+
543
+ Returns:
544
+ TrainCycleItem: memory (numel) cost of such comm_action_sequence
545
+ """
546
+
547
+ def compute_shape(sharding_spec: ShardingSpec):
548
+ if 'opt' == mem_pattern:
549
+ return sharding_spec.get_sharded_shape_per_device()
550
+ elif 'max' == mem_pattern:
551
+ return sharding_spec.get_max_sharded_shape_per_device()
552
+ else:
553
+ return 0.0
554
+
555
+ def gather_analysis(comm_spec, peak_mem):
556
+ """analyze all_gather memory footprint
557
+ all_gather will allocate memory for the output tensor, and there will be temp memory for
558
+ all_gather operation, which is twice the size of output tensor
559
+
560
+ Args:
561
+ comm_spec (CommSpec): input CommSpec
562
+ """
563
+ input_shape = compute_shape(comm_spec.sharding_spec)
564
+ input_numel = reduce(operator.mul, input_shape, 1)
565
+ for axes in comm_spec.logical_process_axis:
566
+ for axis in axes:
567
+ output_numel = input_numel * comm_spec.device_mesh.mesh_shape[
568
+ axis]
569
+ alloc_mem = (input_numel +
570
+ output_numel * 2) * comm_spec.sharding_spec.dtype_size
571
+ peak_mem = max(peak_mem, alloc_mem)
572
+ return peak_mem
573
+
574
+ def reduce_scatter_analysis(comm_spec, peak_mem):
575
+
576
+ input_shape = compute_shape(comm_spec.sharding_spec)
577
+ input_numel = reduce(operator.mul, input_shape, 1)
578
+ output_numel = input_numel
579
+ for axes in comm_spec.logical_process_axis:
580
+ for axis in axes:
581
+ output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
582
+ axis]
583
+ alloc_mem = (input_numel +
584
+ output_numel * 2) * comm_spec.sharding_spec.dtype_size
585
+ peak_mem = max(peak_mem, alloc_mem)
586
+
587
+ return peak_mem
588
+
589
+ def split_analysis(comm_spec: CommSpec, peak_mem: int):
590
+ """analyze split memory footprint
591
+ split will allocate memory for the output tensor if we don't apply shard on the first dimension of
592
+ the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
593
+ generate new tensor in this case, so no memory will be allocated.
594
+
595
+ Args:
596
+ comm_spec (CommSpec): input CommSpec
597
+ peak_mem (int): current peak memory footprint in bytes
600
+ """
601
+ shard_dim = comm_spec.shard_dim
602
+ if shard_dim != 0:
603
+ # if we don't shard the tensor on the first dimension, the split action will
604
+ # generate a new tensor
605
+ input_shape = compute_shape(comm_spec.sharding_spec)
606
+ input_numel = reduce(operator.mul, input_shape, 1)
607
+ output_numel = input_numel
608
+ for axes in comm_spec.logical_process_axis:
609
+ for axis in axes:
610
+ output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
611
+ axis]
612
+ alloc_mem = (input_numel +
613
+ output_numel) * comm_spec.sharding_spec.dtype_size
614
+ peak_mem = max(peak_mem, alloc_mem)
615
+ else:
616
+ # if we shard the tensor on the first dimension, the split action will not generate
617
+ # a new tensor, and as it will preserve a reference to the input tensor, we could
618
+ # override the discard_input option here
619
+ # NOTE: this special case might fail in some weird cases, e.g. if we have three split
620
+ # actions in the comm actions sequence, the first split action operate on the second dimension,
621
+ # the second split action operate on the first dimension, and the third split action operate, again,
622
+ # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
623
+ # memory the same size as the output of first split action. However, the third split action will discard
624
+ # the input tensor, and it actually should discard the tensor generated by the first split action, so in
625
+ # the current memory estimation framework, we will overestimate the memory usage. But the above case is
626
+ # kind of weird, and I think we could ignore it for now.
627
+ pass
628
+ return peak_mem
629
+
630
+ def reduce_analysis(comm_spec: CommSpec, peak_mem: int):
631
+ input_shape = compute_shape(comm_spec.sharding_spec)
632
+ input_numel = reduce(operator.mul, input_shape, 1)
633
+ output_numel = input_numel
634
+ alloc_mem = (input_numel +
635
+ output_numel) * comm_spec.sharding_spec.dtype_size
636
+ peak_mem = max(peak_mem, alloc_mem)
637
+ return peak_mem
638
+
639
+ def all2all_analysis(comm_spec: CommSpec, peak_mem: int):
640
+ input_shape = compute_shape(comm_spec.sharding_spec)
641
+ input_numel = reduce(operator.mul, input_shape, 1)
642
+ output_numel = input_numel
643
+ comm_spec.shard_dim
644
+ alloc_mem = (input_numel +
645
+ output_numel * 3) * comm_spec.sharding_spec.dtype_size
646
+ peak_mem = max(peak_mem, alloc_mem)
647
+ return peak_mem
648
+
649
+ def peer_to_peer_analysis(comm_spec: CommSpec, peak_mem: int):
650
+ input_shape = compute_shape(comm_spec.sharding_spec)
651
+ input_numel = reduce(operator.mul, input_shape, 1)
652
+ alloc_mem = (input_numel) * comm_spec.sharding_spec.dtype_size
653
+ peak_mem = max(peak_mem, alloc_mem)
654
+ return peak_mem
655
+
656
+ pattern_to_func_dict = {
657
+ 'all_gather': gather_analysis,
658
+ 'all_to_all': all2all_analysis,
659
+ 'split': split_analysis,
660
+ 'all_reduce': reduce_analysis,
661
+ 'reduce_scatter': reduce_scatter_analysis,
662
+ 'peer_to_peer': peer_to_peer_analysis
663
+ }
664
+
665
+ fwd_actions = []
666
+ # construct forward and backward comm actions sequence
667
+ for comm_spec in comm_action_sequence:
668
+ fwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
669
+ fwd_actions.append(fwd_action)
670
+
671
+ # analyze memory footprint of forward comm actions sequence
672
+ fwd_peak_numel = 0
673
+ for idx, action_spec_pair in enumerate(
674
+ zip(fwd_actions, comm_action_sequence)):
675
+ # the first forward comm action will not discard input
676
+ fwd_action, comm_spec = action_spec_pair
677
+ fwd_peak_numel = fwd_action(comm_spec, fwd_peak_numel)
678
+
679
+ return fwd_peak_numel
680
+
681
+ def print_shape_consistency_result(self,
682
+ transform_path,
683
+ comm_action_sequence,
684
+ resharding_cost,
685
+ file=None):
686
+ for idx, tpath in enumerate(transform_path):
687
+ print(
688
+ f'sharding_info = [op_shape:{tpath.entire_shape}, sharding_spec:{tpath.sharding_sequence}, sharded_shape:{tpath.get_sharded_shape_per_device()}]',
689
+ end=" ",
690
+ file=file)
691
+ print('->', end=" ", file=file)
692
+ try:
693
+ commspec = comm_action_sequence[idx]
694
+ comm = [
695
+ commspec.comm_pattern, commspec.gather_dim,
696
+ commspec.shard_dim, commspec.logical_process_axis
697
+ ]
698
+ except IndexError:
699
+ comm = ''
700
+ print(f'comm_info = {comm}', end=" ", file=file)
701
+ print('->', end=" ", file=file)
702
+ print(f'total_cost = {resharding_cost}', file=file)
703
+
704
+ def construct_transform_path_from_cache(self, src_spec, target_spec,
705
+ old_transform_path,
706
+ old_comm_action_sequence,
707
+ orig_cost):
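+ # Replay a cached transform path on a new source spec: each cached CommSpec and
+ # ShardingSpec is re-instantiated on top of src_spec so that shapes, dtype and the
+ # accumulated communication cost reflect the current tensor rather than the cached one.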
708
+ new_transform_path = [src_spec]
709
+ new_comm_action_sequence = []
710
+ new_cost = orig_cost
711
+ new_src_spec = src_spec
712
+ for idx, old_comm_spec in enumerate(old_comm_action_sequence):
713
+ new_comm_spec = CommSpec(
714
+ old_comm_spec.comm_pattern,
715
+ sharding_spec=new_src_spec,
716
+ gather_dim=old_comm_spec.gather_dim,
717
+ shard_dim=old_comm_spec.shard_dim,
718
+ logical_process_axis=old_comm_spec.logical_process_axis,
719
+ forward_only=old_comm_spec.forward_only,
720
+ mix_gather=old_comm_spec.mix_gather)
721
+ new_comm_action_sequence.append(new_comm_spec)
722
+ new_cost += new_comm_spec.get_comm_cost()
723
+ old_target_spec = old_transform_path[idx + 1]
724
+ new_target_spec = ShardingSpec(new_src_spec.device_mesh,
725
+ new_src_spec.data_type_size,
726
+ new_src_spec.entire_shape,
727
+ new_src_spec.max_entire_shape,
728
+ new_src_spec.raw_shape,
729
+ old_target_spec.dim_partition_dict)
730
+ new_transform_path.append(new_target_spec)
731
+ new_src_spec = new_target_spec
732
+ assert new_transform_path[-1].get_sharded_shape_per_device(
733
+ ) == target_spec.get_sharded_shape_per_device(
734
+ ), 'failed to insert the cache transform path'
735
+ return new_transform_path, new_comm_action_sequence, new_cost
736
+
737
+ def shape_consistency(
738
+ self, source_spec: ShardingSpec, target_spec: ShardingSpec
739
+ ) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
740
+ '''
741
+ This method will find a path to transform source_spec to target_spec with
742
+ a greedy algorithm.
743
+ The basic idea is:
744
+ Step1:
745
+ Generate all one-step transform sequences from source_spec.
746
+ Step2:
747
+ Pick the 'best' sharding spec following the heuristic function.
748
+ Step3:
749
+ Repeat above steps until the source spec transform to target spec.
750
+ '''
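+ # Illustrative behaviour (assuming a 2-D mesh): [S0, S1, R] -> [R, R, R] completes in a
+ # single step (a mix-gather over both mesh axes), whereas [S0, R, R] -> [R, S1, R] is
+ # found greedily as an all-gather to [R, R, R] followed by a split on mesh axis 1.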
751
+ MAX_TRANSFORM_STEPS = 20
752
+ total_cost = 0.0
753
+ total_steps = 0
754
+ transform_path = []
755
+ comm_action_sequence = []
756
+ # We do nothing if the sharding spec is all the same.
757
+ if source_spec.sharding_sequence_difference(target_spec) == 0:
758
+ return (transform_path, comm_action_sequence, total_cost)
759
+
760
+ spec_pairs = (str(source_spec.sharding_sequence),
761
+ str(target_spec.sharding_sequence))
762
+
763
+ if spec_pairs in self.cached_spec_pairs_transform_path:
764
+ transform_path, comm_action_sequence = self.cached_spec_pairs_transform_path[
765
+ spec_pairs]
766
+ new_transform_path, new_comm_action_sequence, new_total_cost = self.construct_transform_path_from_cache(
767
+ source_spec, target_spec, transform_path, comm_action_sequence,
768
+ total_cost)
769
+ self.cache_hit += 1
770
+ return (new_transform_path, new_comm_action_sequence,
771
+ new_total_cost)
772
+
773
+ else:
774
+ self.cache_miss += 1
775
+
776
+ temp_sharding_spec = source_spec
777
+ transform_path.append(temp_sharding_spec)
778
+ # To avoid an infinite loop, the loop will break after MAX_TRANSFORM_STEPS transforms
779
+ while total_steps <= MAX_TRANSFORM_STEPS:
780
+ valid_transform_spec_dict = self.get_all_one_step_transform_spec(
781
+ temp_sharding_spec, total_cost)
782
+ best_difference_score = math.inf
783
+
784
+ for sharding_spec, info_pairs in valid_transform_spec_dict.items():
785
+ comm_spec, cost = info_pairs
786
+ spec_difference = sharding_spec.sharding_sequence_difference(
787
+ target_spec)
788
+
789
+ if spec_difference == 0:
790
+ total_cost = cost
791
+ transform_path.append(sharding_spec)
792
+ comm_action_sequence.append(comm_spec)
793
+ self.cached_spec_pairs_transform_path[spec_pairs] = (
794
+ transform_path, comm_action_sequence)
795
+ return (transform_path, comm_action_sequence, total_cost)
796
+
797
+ if spec_difference < best_difference_score:
798
+ temp_sharding_spec = sharding_spec
799
+ temp_cost = cost
800
+ temp_comm_spec = comm_spec
801
+ best_difference_score = spec_difference
802
+
803
+ transform_path.append(temp_sharding_spec)
804
+ comm_action_sequence.append(temp_comm_spec)
805
+ total_cost = temp_cost
806
+ total_steps += 1
807
+
808
+ raise RuntimeError(
809
+ f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps."
810
+ )
811
+
812
+ def dum_transform_path_from_cache(self):
813
+ src_specs, tgt_specs, path_strs = [], [], []
814
+ for spec_pairs, trans_comm_path in self.cached_spec_pairs_transform_path.items(
815
+ ):
816
+ src_specs.append(spec_pairs[0])
817
+ tgt_specs.append(spec_pairs[1])
818
+ trans_paths, comm_specs = trans_comm_path[0], trans_comm_path[1]
819
+ path_str = f'{spec_pairs[0]}->'
820
+ for idx in range(1, len(trans_paths)):
821
+ comm_spec = comm_specs[idx - 1]
822
+ comm_str = f'{comm_spec.comm_pattern}: gather_dim{comm_spec.gather_dim}, shard_dim{comm_spec.shard_dim}, mesh_axis{comm_spec.logical_process_axis}->'
823
+ path_str += comm_str
824
+ path_str += f'{trans_paths[idx].sharding_sequence}->'
825
+ path_strs.append(path_str)
826
+ ret_dict = {
827
+ 'src_spec': src_specs,
828
+ 'dst_specs': tgt_specs,
829
+ 'trans_path': path_strs
830
+ }
831
+ ret_df = pd.DataFrame.from_dict(ret_dict)
832
+ return ret_df
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shape_node.py ADDED
@@ -0,0 +1,41 @@
1
+ import copy
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Shape(Node):
8
+
9
+ def _update_memory_cost(self, strategies):
10
+ pass
11
+
12
+ def _collect_strategies(self, device_mesh):
13
+ # one input for shape node
14
+ predecessor = self.predecessor_nodes[0]
15
+ strategies_vector = StrategiesVector(self)
16
+ for idx, strategy in enumerate(predecessor.strategies_vector):
17
+ # current node's local name input0 -> global name xxx
18
+ global_input_name = self.op_data['input0'].name
19
+ # global name xxx -> pre node local output name
20
+ prenode_local_name = predecessor.global_to_local_op_name[
21
+ global_input_name]
22
+ dim_partition_dict = copy.deepcopy(
23
+ strategy.sharding_specs[prenode_local_name].dim_partition_dict)
24
+ in0_partition_dict = dim_partition_dict
25
+ dim_partition_dict_mapping = {
26
+ "input0": in0_partition_dict,
27
+ "output0": {},
28
+ }
29
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
30
+ dim_partition_dict_mapping, device_mesh)
31
+ if 0 == len(sharding_spec_mapping):
32
+ return strategies_vector
33
+ name = '{} = <shape> {}'.format(
34
+ sharding_spec_mapping['output0'].sharding_sequence,
35
+ sharding_spec_mapping['input0'].sharding_sequence)
36
+ sharding_strategy = self._get_sharding_strategy(
37
+ name=name,
38
+ sharding_spec_mapping=sharding_spec_mapping,
39
+ communication_action_mapping={})
40
+ strategies_vector.append(sharding_strategy)
41
+ return strategies_vector
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_spec.py ADDED
@@ -0,0 +1,418 @@
1
+ import operator
2
+ from functools import reduce
3
+
4
+ import tensorrt as trt
5
+
6
+ from tensorrt_llm.logger import logger
7
+
8
+ ALLGATHER_COST = 20
9
+ SHARD_COST = 5
10
+ STEP_PENALTY = 6
11
+ NAN = 'nan'
12
+
13
+
14
+ def _convert_str_to_shard_list(str_spec):
15
+ '''
16
+ Convert str_spec into shard_list.
17
+
18
+ Argument:
19
+ str_spec(str): dim spec in str type.
20
+ '''
21
+
22
+ if str_spec == 'R':
23
+ return []
24
+ if str_spec == 'S0':
25
+ return [0]
26
+ if str_spec == 'S1':
27
+ return [1]
28
+ if str_spec == 'S01':
29
+ return [0, 1]
30
+
31
+
32
+ def _build_difference_2d_dict():
33
+ '''
34
+ Build a difference mapping for 2D device mesh case. It will be used to
35
+ compute the difference between DimSpec pairs.
36
+ '''
37
+
38
+ source_spec_list = ['R', 'S0', 'S1', 'S01']
39
+ target_spec_list = ['R', 'S0', 'S1', 'S01']
40
+ difference_dict = {}
41
+ for source_spec in source_spec_list:
42
+ for target_spec in target_spec_list:
43
+ spec_pair = (source_spec, target_spec)
44
+ source_shard_list = _convert_str_to_shard_list(source_spec)
45
+ target_shard_list = _convert_str_to_shard_list(target_spec)
46
+
47
+ # source same as target
48
+ if source_shard_list == target_shard_list:
49
+ difference = 0
50
+
51
+ # all_gather(source) -> target
52
+ elif len(source_shard_list) == len(
53
+ target_shard_list
54
+ ) + 1 and source_shard_list[:-1] == target_shard_list:
55
+ difference = ALLGATHER_COST
56
+
57
+ # shard(source) -> target
58
+ elif len(source_shard_list) == len(
59
+ target_shard_list
60
+ ) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
61
+ -1] not in source_shard_list:
62
+ difference = SHARD_COST
63
+
64
+ # S1 -> S0 or S0 -> S1
65
+ elif len(source_shard_list) == len(target_shard_list):
66
+ # source -> R -> target
67
+ difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
68
+
69
+ # R -> S01
70
+ elif len(source_shard_list) == len(target_shard_list) - 2:
71
+ difference = SHARD_COST + STEP_PENALTY + SHARD_COST
72
+
73
+ # S01 -> R
74
+ elif len(source_shard_list) == len(target_shard_list) + 2:
75
+ difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
76
+
77
+ # S1 -> S01
78
+ elif len(source_shard_list) == len(target_shard_list) - 1:
79
+ difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
80
+
81
+ # S01 -> S1
82
+ elif len(source_shard_list) == len(target_shard_list) + 1:
83
+ difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
84
+
85
+ else:
86
+ difference = NAN
87
+ difference_dict[spec_pair] = difference
88
+
89
+ return difference_dict
90
+
91
+
92
+ _difference_dict = _build_difference_2d_dict()
93
+
94
+
95
+ class DimSpec:
96
+ '''
97
+ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of
98
+ logical device mesh and give a method to compute the difference between them.
99
+
100
+ Argument:
101
+ shard_list(List[int]): if shard_list is empty, the dim spec will be 'R' type.
102
+ Otherwise, the element in shard_list means the data will be sharded in that dimension.
103
+ '''
104
+
105
+ def __init__(self, shard_list):
106
+ self.is_replica = len(shard_list) == 0
107
+ self.shard_list = shard_list
108
+
109
+ def __eq__(self, other):
110
+ return str(self) == str(other)
111
+
112
+ def __repr__(self):
113
+ if self.is_replica:
114
+ return 'R'
115
+ target = 'S'
116
+ for dim in self.shard_list:
117
+ target += str(dim)
118
+ return target
119
+
120
+ def difference(self, other):
121
+ '''
122
+ The difference between two DimSpec.
123
+
124
+ Argument:
125
+ other(DimSpec): the dim spec to compare with.
126
+
127
+ Return:
128
+ difference(int): the difference between two DimSpec.
129
+
130
+ Example:
131
+ dim_spec = DimSpec([0])
132
+ other_dim_spec = DimSpec([0, 1])
133
+ print(dim_spec.difference(other_dim_spec))
134
+
135
+ Output:
136
+ 5
137
+ '''
138
+ difference = _difference_dict[(str(self), str(other))]
139
+ return difference
140
+
141
+
142
+ def get_sharding_sequence(num_dims, dims, device_dims):
143
+ sharding_sequence = [DimSpec([])] * num_dims
144
+ for dim, shard_list in zip(dims, device_dims):
145
+ sharding_sequence[dim] = DimSpec(shard_list)
146
+ return sharding_sequence
147
+
148
+
149
+ class ShardingSpec:
150
+ '''
151
+ Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
152
+ to, the entire shape of the tensor before sharded, and the sharding sequence looks like
153
+ [R, R, S0, S1].
154
+
155
+ Argument:
156
+ device_mesh: A logical view of a physical mesh.
157
+ entire_shape: The entire shape of tensor before sharded.
158
+ dim_partition_dict: The key is the dimension of tensor to be sharded,
159
+ and the value of the key describe which logical axis will be sharded in that dimension.
160
+ sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
161
+ '''
162
+
163
+ def __init__(self,
164
+ device_mesh,
165
+ data_type_size,
166
+ data_shape,
167
+ max_data_shape,
168
+ raw_data_shape,
169
+ dim_partition_dict=None,
170
+ sharding_sequence=None):
171
+ self.device_mesh = device_mesh
172
+ self.data_type_size = data_type_size
173
+ self.dtype = data_type_size[0]
174
+ self.dtype_size = data_type_size[1]
175
+ self.entire_shape = data_shape
176
+ self.max_entire_shape = max_data_shape
177
+ self.raw_shape = raw_data_shape
178
+ self.dim_partition_dict = dim_partition_dict
179
+ self.sharding_sequence = sharding_sequence
180
+ self.enable_shard_unbalanced_shape = device_mesh.config.enable_shard_unbalanced_shape
181
+ self.enable_shard_dynamic_shape = device_mesh.config.enable_shard_dynamic_shape
182
+ if self.sharding_sequence is None:
183
+ self.dim_partition_dict = self._merge_same_dim_mesh_list(
184
+ len(self.entire_shape), self.dim_partition_dict)
185
+ self.dim_partition_dict = self._convert_dim_partition_dict(
186
+ len(self.entire_shape), self.dim_partition_dict)
187
+ assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
188
+ self.convert_dict_to_shard_sequence()
189
+ elif self.dim_partition_dict is None:
190
+ assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
191
+ self.convert_shard_sequence_to_dict()
192
+ self.dim_partition_dict = self._merge_same_dim_mesh_list(
193
+ len(self.entire_shape), self.dim_partition_dict)
194
+ self.dim_partition_dict = self._convert_dim_partition_dict(
195
+ len(self.entire_shape), self.dim_partition_dict)
196
+
197
+ self.sharded_shape, self.max_sharded_shape = [*self.entire_shape], [
198
+ *self.max_entire_shape
199
+ ]
200
+ for dim, shard_list in self.dim_partition_dict.items():
201
+ mesh_list = [
202
+ self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list
203
+ ]
204
+ shard_partitions = reduce(operator.mul, mesh_list, 1)
205
+ self.sharded_shape[dim] = (self.sharded_shape[dim] +
206
+ shard_partitions - 1) // shard_partitions
207
+ self.max_sharded_shape[dim] = (self.max_sharded_shape[dim] +
208
+ shard_partitions -
209
+ 1) // shard_partitions
210
+
211
+ def print_spec(self, file=None):
212
+ print(
213
+ f"sharding_sequence = {self.sharding_sequence}, shape = {self.get_sharded_shape_per_device()}",
214
+ file=file,
215
+ )
216
+
217
+ def _merge_same_dim_mesh_list(self, dim_size, dim_partition_dict):
218
+ '''
219
+ This method is used to merge the different key value which points to same physical position.
220
+
221
+ For example:
222
+ dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
223
+ In this method, above dim_partition_dict will be converted to {1: [0, 1]}
224
+ '''
225
+ converted_dim_partition_dict = {}
226
+ for dim, mesh_list in dim_partition_dict.items():
227
+ if dim < 0:
228
+ dim = dim_size + dim
229
+ if dim not in converted_dim_partition_dict:
230
+ converted_dim_partition_dict[dim] = mesh_list
231
+ else:
232
+ converted_dim_partition_dict[dim].extend(mesh_list)
233
+ converted_dim_partition_dict[dim].sort()
234
+ return converted_dim_partition_dict
235
+
236
+ def _convert_dim_partition_dict(self, dim_size, dim_partition_dict):
237
+ dims_to_convert = []
238
+ for dim, mesh_list in dim_partition_dict.items():
239
+ if dim < 0:
240
+ dims_to_convert.append(dim)
241
+ for dim in dims_to_convert:
242
+ dim_partition_dict.pop(dim)
243
+ dim_partition_dict[dim_size + dim] = mesh_list
244
+ return dim_partition_dict
245
+
246
+ def _remove_mesh_dim_one(self, dim_partition_dict):
247
+ dims_to_remove = []
248
+ for dim, mesh_list in dim_partition_dict.items():
249
+ new_mesh_list = []
250
+ for mesh_dim in mesh_list:
251
+ if self.device_mesh.mesh_shape[mesh_dim] != 1:
252
+ new_mesh_list.append(mesh_dim)
253
+ if 0 != len(new_mesh_list):
254
+ dim_partition_dict[dim] = new_mesh_list
255
+ else:
256
+ dims_to_remove.append(dim)
257
+ for dim in dims_to_remove:
258
+ dim_partition_dict.pop(dim)
259
+ return dim_partition_dict
260
+
261
+ def __repr__(self):
262
+ res = "DistSpec("
263
+ res += f"shard_sequence={self.sharding_sequence},"
264
+ res += f"shape={self.device_mesh.mesh_shape}"
265
+ res += ")"
266
+ return res
267
+
268
+ def sanity_check(self):
269
+ # make sure all axes in logical device mesh only be used once
270
+ dim_check_list = [*range(len(self.device_mesh.mesh_shape))]
271
+ for dim, shard_list in self.dim_partition_dict.items():
272
+ for element in shard_list:
273
+ if element in dim_check_list:
274
+ dim_check_list.remove(element)
275
+ else:
276
+ logger.warning(
277
+ f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}. dim_partition_dict={self.dim_partition_dict}"
278
+ )
279
+ return False
280
+
281
+ # make sure that the dimension is not out of index
282
+ for dim in self.dim_partition_dict.keys():
283
+ # we have tried to convert the negative value to positive value, if it is larger than the dim_size or negative still, it is out of index
284
+ if dim >= len(self.entire_shape) or dim < 0:
285
+ print(
286
+ f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
287
+ )
288
+ return False
289
+
290
+ if not self.enable_shard_dynamic_shape:
291
+ # make sure to not to shard on dynamic shape
292
+ for dim, shard_list in self.dim_partition_dict.items():
293
+ if len(shard_list) == 0:
294
+ continue
295
+ if len(self.raw_shape) == 0:
296
+ continue
297
+ if -1 == self.raw_shape[dim]:
298
+ return False
299
+
300
+ # make sure that the sharding for a dimension is divisible by the number of devices
301
+ for dim, shard_list in self.dim_partition_dict.items():
302
+ if len(shard_list) == 0:
303
+ continue
304
+ tensor_dim_size = self.entire_shape[dim]
305
+ num_devices = 1
306
+
307
+ for element in shard_list:
308
+ num_devices *= self.device_mesh.mesh_shape[element]
309
+ if num_devices == 1:
310
+ # we only support RR when the device is 1
311
+ return False
312
+
313
+ if not self.enable_shard_unbalanced_shape:
314
+ if tensor_dim_size % num_devices != 0 or tensor_dim_size == 1:
315
+ '''
316
+ print(
317
+ f'The size of static dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
318
+ )
319
+ '''
320
+ return False
321
+ else:
322
+ if tensor_dim_size == 1:
323
+ return False
324
+ '''
325
+ if self.get_sharded_size_per_device() > (2**31 - 1):
326
+ print(
327
+ f'memory footprint per device {self.get_sharded_size_per_device()} is larger than 2**31 - 1'
328
+ )
329
+ return False
330
+ '''
331
+ return True
332
+
333
+ def convert_dict_to_shard_sequence(self):
334
+ '''
335
+ Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
336
+ '''
337
+ sharding_sequence = [DimSpec([])] * len(self.entire_shape)
338
+ for dim, shard_list in self.dim_partition_dict.items():
339
+ sharding_sequence[dim] = DimSpec(shard_list)
340
+ self.sharding_sequence = sharding_sequence
341
+
342
+ def convert_shard_sequence_to_dict(self):
343
+ '''
344
+ Convert sharding_sequence into dim_partition_dict.
345
+ '''
346
+ new_dim_partition_dict = {}
347
+ for index, dim_spec in enumerate(self.sharding_sequence):
348
+ if not dim_spec.is_replica:
349
+ if index not in new_dim_partition_dict:
350
+ new_dim_partition_dict[index] = []
351
+ new_dim_partition_dict[index].extend(dim_spec.shard_list)
352
+ self.dim_partition_dict = new_dim_partition_dict
353
+
354
+ def sharding_sequence_difference(self, other):
355
+ '''
356
+ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the
357
+ pair of sharding sequence.
358
+
359
+ Example:
360
+ dim_partition_dict = {0: [0, 1]}
361
+ # DistSpec:
362
+ # shard_sequence: S01,R,R
363
+ # device_mesh_shape: (4, 4)
364
+ sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
365
+ dim_partition_dict_to_compare = {0: [0], 1: [1]}
366
+ # DistSpec:
367
+ # shard_sequence: S0,S1,R
368
+ # device_mesh_shape: (4, 4)
369
+ sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
370
+ print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
371
+
372
+ Output:
373
+ 25
374
+
375
+ Argument:
376
+ other(ShardingSpec): The ShardingSpec to compared with.
377
+
378
+ Return:
379
+ difference(int): Difference between two ShardingSpec.
380
+ '''
381
+ assert len(self.sharding_sequence) == len(
382
+ other.sharding_sequence
383
+ ), f'Cannot compare difference for two sharding specs with different length.'
384
+ difference = 0
385
+ for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence,
386
+ other.sharding_sequence):
387
+ difference += orig_dim_spec.difference(other_dim_spec)
388
+ return difference
389
+
390
+ def get_sharded_shape_per_device(self, ):
391
+ return self.sharded_shape
392
+
393
+ def get_sharded_element_per_device(self, ):
394
+ sharded_shape = self.get_sharded_shape_per_device()
395
+ if len(sharded_shape) == 0:
396
+ num_elements = 1
397
+ else:
398
+ num_elements = trt.volume(sharded_shape)
399
+ return num_elements
400
+
401
+ def get_sharded_size_per_device(self, ):
402
+ num_elements = self.get_sharded_element_per_device()
403
+ return num_elements * self.dtype_size
404
+
405
+ def get_max_sharded_shape_per_device(self, ):
406
+ return self.max_sharded_shape
407
+
408
+ def get_max_sharded_element_per_device(self, ):
409
+ max_sharded_shape = self.get_max_sharded_shape_per_device()
410
+ if len(max_sharded_shape) == 0:
411
+ num_elements = 1
412
+ else:
413
+ num_elements = trt.volume(max_sharded_shape)
414
+ return num_elements
415
+
416
+ def get_max_sharded_size_per_device(self, ):
417
+ num_elements = self.get_max_sharded_element_per_device()
418
+ return num_elements * self.dtype_size
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/sharding_strategy.py ADDED
@@ -0,0 +1,77 @@
1
+ class ShardingStrategy(object):
2
+
3
+ def __init__(self,
4
+ name=None,
5
+ sharding_specs=None,
6
+ communication_actions=None):
7
+ self.name = name or ""
8
+ self.sharding_specs = sharding_specs or {}
9
+ self.communication_actions = communication_actions
10
+ self.sharding_cost = 0
11
+ self.communication_cost = 0
12
+ self.resharding_costs = {}
13
+ self.best_resharding_cost = {}
14
+ self.node_names = {}
15
+
16
+ self.comm_buff_memory_footprint = 0
17
+ self.inout_memory_footprint = 0
18
+ self.const_memory_footprint = 0
19
+ self.peak_memory_footprint = 0
20
+ self.computation_macs = 0
21
+ self.alpha_beta_cost = 0
22
+
23
+ def print_strategy(self, best_resharding_cost_only=False, file=None):
24
+
25
+ def print_resharding_costs(resharding_cost):
26
+ for prenode_node_name, rcosts in resharding_cost.items():
27
+ if isinstance(prenode_node_name, int):
28
+ idx = prenode_node_name
29
+ prenode_node_name = self.node_names[idx]
30
+ print(f' pre_node = {idx} {prenode_node_name}',
31
+ file=file)
32
+ else:
33
+ print(f' pre_node = {prenode_node_name}', file=file)
34
+ for idx, rcost in enumerate(rcosts):
35
+ transpaths, commspecs, cost = rcost
36
+ print(f' {idx}: ', end=' ', file=file)
37
+ device_mesh.shape_consistency_manager.print_shape_consistency_result(
38
+ transpaths, commspecs, cost, file)
39
+
40
+ print(f'name = {self.name}', file=file)
41
+ print(f'sharding_cost = {self.sharding_cost}', file=file)
42
+ print(
43
+ f'communication_buffer_memory_footprint = {self.comm_buff_memory_footprint}, communication_cost = {self.communication_cost}',
44
+ file=file)
45
+ print(f'inout_memory_footprint = {self.inout_memory_footprint}',
46
+ file=file)
47
+ print(f'peak_memory_footprint = {self.peak_memory_footprint}',
48
+ file=file)
49
+ print(f'const_memory_footprint = {self.const_memory_footprint}',
50
+ file=file)
51
+ print('sharding_specs:', file=file)
52
+ device_mesh = None
53
+ for specname, spec in self.sharding_specs.items():
54
+ print(specname + ', ', end=' ', file=file)
55
+ spec.print_spec(file)
56
+ device_mesh = spec.device_mesh
57
+
58
+ if best_resharding_cost_only and self.best_resharding_cost:
59
+ print('best_resharding_costs:', file=file)
60
+ print_resharding_costs(self.best_resharding_cost)
61
+ else:
62
+ print('resharding costs:', file=file)
63
+ print_resharding_costs(self.resharding_costs)
64
+
65
+
66
+ class StrategiesVector(list):
67
+ '''
68
+ Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
69
+ strategies of the node.
70
+
71
+ Argument:
72
+ node (Node): node for which the list of sharding strategies are generated.
73
+ '''
74
+
75
+ def __init__(self, node):
76
+ super().__init__()
77
+ self.node = node
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/shuffle_node.py ADDED
@@ -0,0 +1,238 @@
1
+ import copy
2
+
3
+ from .node import Node
4
+ from .sharding_strategy import StrategiesVector
5
+
6
+
7
+ class Shuffle(Node):
8
+
9
+ def __init__(self, layer):
10
+ super().__init__(layer)
11
+ layer.to_subclass()
12
+ self.first_transpose_dims = layer.as_trt().first_transpose
13
+ self.second_transpose_dims = layer.as_trt().second_transpose
14
+ self.zero_is_placeholder = layer.as_trt().zero_is_placeholder
15
+ self.is_first_transepose_identity = (sorted(
16
+ self.first_transpose_dims) == list(self.first_transpose_dims))
17
+ self.input_shape = self.get_input(0).shape
18
+ self.is_second_transepose_identity = (sorted(
19
+ self.second_transpose_dims) == list(self.second_transpose_dims))
20
+
21
+ output_shape = list(self.get_output(0).shape)
22
+ self.reshape_dims = copy.deepcopy(output_shape)
23
+ if not self.is_second_transepose_identity:
24
+ for i in self.second_transpose_dims:
25
+ if self.second_transpose_dims[i] != i:
26
+ self.reshape_dims[
27
+ self.second_transpose_dims[i]] = output_shape[i]
28
+ self.is_reshape_identity = (list(self.reshape_dims) == list(
29
+ self.input_shape))
30
+ layer.to_base_class()
31
+
32
+ def _collect_transpose_strategies(self, device_mesh, transpose_dims):
33
+ dim_partition_list = []
34
+ dim_size = len(self.op_data['input0'].shape)
35
+ dim_partition_list.append({})
36
+ dim_partition_list.extend(
37
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
38
+ dim_partition_list.extend(
39
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
40
+ dim_partition_list.extend(
41
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
42
+ dim_partition_list.extend(
43
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
44
+ strategies_vector = StrategiesVector(self)
45
+ # dim_partition_dict can be the same as previous node if solver's time is a problem
46
+ for dim_partition_dict in dim_partition_list:
47
+ in0_partition_dict = dim_partition_dict
48
+ out_partition_dict = {}
49
+ for split_dim, mesh_dim in in0_partition_dict.items():
50
+ trans_dim = transpose_dims[split_dim]
51
+ out_partition_dict[trans_dim] = mesh_dim
52
+
53
+ dim_partition_dict_mapping = {
54
+ "input0": in0_partition_dict,
55
+ "output0": out_partition_dict,
56
+ }
57
+ if self.num_inputs == 2:
58
+ dim_partition_dict_mapping["input1"] = {}
59
+
60
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
61
+ dim_partition_dict_mapping, device_mesh)
62
+ if 0 == len(sharding_spec_mapping):
63
+ continue
64
+ name = '{} = <shuffle_transpose_only op> {}'.format(
65
+ sharding_spec_mapping['output0'].sharding_sequence,
66
+ sharding_spec_mapping['input0'].sharding_sequence)
67
+ sharding_strategy = self._get_sharding_strategy(
68
+ name=name,
69
+ sharding_spec_mapping=sharding_spec_mapping,
70
+ communication_action_mapping={})
71
+ strategies_vector.append(sharding_strategy)
72
+ return strategies_vector
73
+
74
+ def _find_reshape_partitions(self, input_shape, output_shape,
75
+ input_partition_dict):
76
+ len_input_shape, len_output_shape = len(input_shape), len(output_shape)
77
+ output_partition_dict = {}
78
+ i, j = 0, 0
79
+ while i < len_input_shape or j < len_output_shape:
80
+ if i < len_input_shape and input_shape[i] == 1:
81
+ i = i + 1
82
+ continue
83
+ if j < len_output_shape and output_shape[j] == 1:
84
+ j = j + 1
85
+ continue
86
+
87
+ if input_shape[i] == output_shape[j]:
88
+ if i in input_partition_dict:
89
+ output_partition_dict[j] = input_partition_dict[i]
90
+ # it keeps the dimension, so we need to keep the partition dims
91
+ i, j = i + 1, j + 1
92
+
93
+ elif input_shape[i] < output_shape[j]:
94
+ # we detect if the input dims are merged in the reshape dim
95
+ value = input_shape[i]
96
+ for ii in range(i + 1, len_input_shape):
97
+ value = value * input_shape[ii]
98
+ if value == output_shape[j]:
99
+ # it is merged, we set the output's merged dim partition as all inputs' dims
100
+ mesh_dim = []
101
+ for in_dim in range(i, ii + 1):
102
+ if in_dim in input_partition_dict:
103
+ mesh_dim = mesh_dim + input_partition_dict[
104
+ in_dim]
105
+ if len(mesh_dim) > 0:
106
+ output_partition_dict[j] = sorted(mesh_dim)
107
+ i, j = ii + 1, j + 1
108
+ break
109
+ else:
110
+ # we didn't find merged dimensions; the difference may come from an arbitrary reshape, which we don't support now
111
+ return {}, {}
112
+ else:
113
+ # we detect if the input dim is split into reshape dims
114
+ value = output_shape[j]
115
+ for jj in range(j + 1, len_output_shape):
116
+ value = value * output_shape[jj]
117
+ if value == input_shape[i]:
118
+ # it is split pattern
119
+ if i in input_partition_dict:
120
+ output_partition_dict[j] = input_partition_dict[i]
121
+ i, j = i + 1, jj + 1
122
+ break
123
+ else:
124
+ # we didn't find split dimensions; the difference may come from an arbitrary reshape
125
+ return {}, {}
126
+ return input_partition_dict, output_partition_dict
127
+
128
+ def _collect_reshape_strategies(self, device_mesh):
129
+ dim_partition_list = []
130
+ dim_size = len(self.op_data['input0'].shape)
131
+ dim_partition_list.append({})
132
+ dim_partition_list.extend(
133
+ self._enumerate_all_possible_1d_sharding([0], dim_size))
134
+ dim_partition_list.extend(
135
+ self._enumerate_all_possible_1d_sharding([1], dim_size))
136
+ dim_partition_list.extend(
137
+ self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
138
+ dim_partition_list.extend(
139
+ self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
140
+ strategies_vector = StrategiesVector(self)
141
+ # dim_partition_dict can be the same as previous node if solver's time is a problem
142
+ for dim_partition_dict in dim_partition_list:
143
+ in0_partition_dict = dim_partition_dict
144
+ in0_partition_dict, out_partition_dict = self._find_reshape_partitions(
145
+ self.input_shape, self.reshape_dims, in0_partition_dict)
146
+ dim_partition_dict_mapping = {
147
+ "input0": in0_partition_dict,
148
+ "output0": out_partition_dict,
149
+ }
150
+ if self.num_inputs == 2:
151
+ dim_partition_dict_mapping["input1"] = {}
152
+ sharding_spec_mapping = self._to_sharding_spec_mapping(
153
+ dim_partition_dict_mapping, device_mesh)
154
+ if 0 == len(sharding_spec_mapping):
155
+ continue
156
+ name = '{} = <shuffle_reshape op> {}'.format(
157
+ sharding_spec_mapping['output0'].sharding_sequence,
158
+ sharding_spec_mapping['input0'].sharding_sequence)
159
+ sharding_strategy = self._get_sharding_strategy(
160
+ name=name,
161
+ sharding_spec_mapping=sharding_spec_mapping,
162
+ communication_action_mapping={})
163
+ strategies_vector.append(sharding_strategy)
164
+ return strategies_vector
165
+
+     def _collect_identity_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['input0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         strategies_vector = StrategiesVector(self)
+         # dim_partition_dict can be the same as previous node if solver's time is a problem
+         for dim_partition_dict in dim_partition_list:
+             in0_partition_dict = dim_partition_dict
+             out_partition_dict = copy.deepcopy(dim_partition_dict)
+             dim_partition_dict_mapping = {
+                 "input0": in0_partition_dict,
+                 "output0": out_partition_dict,
+             }
+             if self.num_inputs == 2:
+                 dim_partition_dict_mapping["input1"] = {}
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = <shuffle_identity op> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 sharding_spec_mapping['input0'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
+ 
+     def _collect_strategies(self, device_mesh):
+         is_identify_list = (self.is_first_transepose_identity,
+                             self.is_reshape_identity,
+                             self.is_second_transepose_identity)
+         if is_identify_list == (True, True, True):
+             return self._collect_identity_strategies(device_mesh)
+         elif is_identify_list == (True, True, False):
+             return self._collect_transpose_strategies(
+                 device_mesh, self.second_transpose_dims)
+         elif is_identify_list == (False, True, True):
+             return self._collect_transpose_strategies(device_mesh,
+                                                       self.first_transpose_dims)
+         elif is_identify_list == (True, False, True):
+             return self._collect_reshape_strategies(device_mesh)
+         else:
+             assert False, f"Unsupported shuffle pattern for now: {is_identify_list}"
+ 
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         updated_layer_attrs = {}
+         updated_input_values = {}
+         output_shape = strategy.sharding_specs[
+             'output0'].get_sharded_shape_per_device()
+         self.layer.to_subclass()
+         second_transpose = self.layer.as_trt().second_transpose
+         self.layer.to_base_class()
+         reshape_dims = [*output_shape]
+         for i in range(len(output_shape)):
+             reshape_dims[second_transpose[i]] = output_shape[i]
+         if self.layer.num_inputs >= 2:
+             updated_input_values[1] = reshape_dims
+         else:
+             updated_layer_attrs['reshape_dims'] = reshape_dims
+         elapsed_time = self.node_runtime_profiler.runtime_profile(
+             self.layer, updated_layer_attrs, updated_input_values, strategy,
+             device_mesh)
+         return elapsed_time
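The same candidate enumeration appears in every _collect_*_strategies method above (and again in slice_node.py and softmax_node.py below): each candidate is a dict mapping tensor dimensions to device-mesh axes. A rough standalone sketch of what the _enumerate_all_possible_1d_sharding / _enumerate_all_possible_2d_sharding helpers (defined on the Node base class, not shown in this diff) are assumed to return for a rank-3 tensor on a 2D mesh:

# Illustrative sketch only -- the exact return format is an assumption,
# not the library implementation.
from itertools import combinations


def enumerate_all_possible_1d_sharding(mesh_axes, dim_size):
    # shard exactly one tensor dim across the given mesh axis (or fused axes)
    return [{dim: list(mesh_axes)} for dim in range(dim_size)]


def enumerate_all_possible_2d_sharding(mesh_axis0, mesh_axis1, dim_size):
    # shard two distinct tensor dims across two distinct mesh axes, in both orders
    result = []
    for d0, d1 in combinations(range(dim_size), 2):
        result.append({d0: list(mesh_axis0), d1: list(mesh_axis1)})
        result.append({d0: list(mesh_axis1), d1: list(mesh_axis0)})
    return result


dim_size = 3  # e.g. a [batch, seq, hidden] tensor
candidates = [{}]  # fully replicated
candidates += enumerate_all_possible_1d_sharding([0], dim_size)
candidates += enumerate_all_possible_1d_sharding([1], dim_size)
candidates += enumerate_all_possible_1d_sharding([0, 1], dim_size)
candidates += enumerate_all_possible_2d_sharding([0], [1], dim_size)
print(len(candidates))  # 16 candidate dim_partition_dicts in this example

Each node then filters this common candidate set according to its own constraints before converting it into sharding specs.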
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/slice_node.py ADDED
@@ -0,0 +1,100 @@
+ import copy
+ 
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+ 
+ 
+ class Slice(Node):
+ 
+     def __init__(self, layer):
+         super().__init__(layer)
+         layer.to_subclass()
+         input_shape = self.get_input(0).shape
+         output_shape = self.get_output(0).shape
+         assert len(input_shape) == len(
+             output_shape
+         ), f'dims of input shape {input_shape} != dims of output shape {output_shape}'
+         if layer.num_inputs >= 2 and layer.get_input(1) is not None:
+             start = layer.get_input(1).value
+         else:
+             start = layer.as_trt().start
+         if layer.num_inputs >= 4 and layer.get_input(3) is not None:
+             stride = layer.get_input(3).value
+         else:
+             stride = layer.as_trt().stride
+         self.keep_partition_dims = [(input_shape[i] == output_shape[i]
+                                      and start[i] == 0 and stride[i] == 1)
+                                     for i in range(len(input_shape))]
+         layer.to_base_class()
+ 
+     def _update_memory_cost(self, strategies):
+         for strategy in strategies:
+             # for the slice node, input0's read size equals output0's write size
+             inout_memory_footprint = strategy.sharding_specs[
+                 'output0'].get_sharded_size_per_device() * 2
+             strategy.inout_memory_footprint = inout_memory_footprint
+             strategy.peak_memory_footprint = (
+                 strategy.sharding_specs['input0'].
+                 get_max_sharded_size_per_device() + strategy.
+                 sharding_specs['output0'].get_max_sharded_size_per_device())
+ 
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['input0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         strategies_vector = StrategiesVector(self)
+         # dim_partition_dict can be the same as previous node if solver's time is a problem
+         for dim_partition_dict in dim_partition_list:
+             for dim in range(len(self.keep_partition_dims)):
+                 if (not self.keep_partition_dims[dim]
+                     ) and dim in dim_partition_dict:
+                     dim_partition_dict.pop(dim)
+ 
+             in0_partition_dict = dim_partition_dict
+             out_partition_dict = copy.deepcopy(dim_partition_dict)
+             dim_partition_dict_mapping = {
+                 "input0": in0_partition_dict,
+                 "output0": out_partition_dict,
+             }
+             for i in range(1, self.num_inputs):
+                 if self.predecessor_nodes[i]:
+                     dim_partition_dict_mapping[f"input{i}"] = {}
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = {} <slice op> '.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 sharding_spec_mapping['input0'].sharding_sequence)
+             for i in range(1, self.num_inputs):
+                 if self.predecessor_nodes[i]:
+                     name = name + str(
+                         sharding_spec_mapping[f'input{i}'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
+ 
+     def _profile_sharding_cost(self, strategy, device_mesh):
+         updated_layer_attrs = {}
+         updated_input_values = {}
+         shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
+         )
+         if self.layer.num_inputs >= 3 and self.layer.get_input(2) is not None:
+             updated_input_values[2] = shape
+         else:
+             updated_layer_attrs['shape'] = shape
+         elapsed_time = self.node_runtime_profiler.runtime_profile(
+             self.layer, updated_layer_attrs, updated_input_values, strategy,
+             device_mesh)
+         return elapsed_time
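For the Slice node above, the key constraint is keep_partition_dims: a tensor dimension remains a sharding candidate only if the slice leaves it untouched (same extent, zero start, unit stride). A minimal sketch of that check, using plain Python lists in place of TensorRT layer inputs (the shapes below are made up for illustration):

# Hypothetical example values; the real start/stride come from the TRT layer.
input_shape = [8, 16]
output_shape = [8, 8]
start = [0, 4]
stride = [1, 1]

keep_partition_dims = [
    input_shape[i] == output_shape[i] and start[i] == 0 and stride[i] == 1
    for i in range(len(input_shape))
]
print(keep_partition_dims)  # [True, False]
# Dim 1 is actually sliced, so _collect_strategies pops dim 1 from every
# candidate dim_partition_dict before building the sharding spec mapping.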
lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/tensor_parallel/softmax_node.py ADDED
@@ -0,0 +1,54 @@
+ import copy
+ 
+ from tensorrt_llm._utils import trt_axes_to_dim
+ 
+ from .node import Node
+ from .sharding_strategy import StrategiesVector
+ 
+ 
+ class SoftMax(Node):
+ 
+     def __init__(self, layer):
+         super().__init__(layer)
+         layer.to_subclass()
+         self.softmax_dim = trt_axes_to_dim(layer.as_trt().axes)[0]
+         layer.to_base_class()
+ 
+     def _collect_strategies(self, device_mesh):
+         dim_partition_list = []
+         dim_size = len(self.op_data['input0'].shape)
+         dim_partition_list.append({})
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+         dim_partition_list.extend(
+             self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+         strategies_vector = StrategiesVector(self)
+         # dim_partition_dict can be the same as previous node if solver's time is a problem
+         for dim_partition_dict in dim_partition_list:
+             if self.softmax_dim in dim_partition_dict:
+                 dim_partition_dict.pop(self.softmax_dim)
+ 
+             in0_partition_dict = dim_partition_dict
+             out_partition_dict = copy.deepcopy(dim_partition_dict)
+             dim_partition_dict_mapping = {
+                 "input0": in0_partition_dict,
+                 "output0": out_partition_dict,
+             }
+             sharding_spec_mapping = self._to_sharding_spec_mapping(
+                 dim_partition_dict_mapping, device_mesh)
+             if 0 == len(sharding_spec_mapping):
+                 continue
+             name = '{} = <softmax along dim {}> {}'.format(
+                 sharding_spec_mapping['output0'].sharding_sequence,
+                 self.softmax_dim,
+                 sharding_spec_mapping['input0'].sharding_sequence)
+             sharding_strategy = self._get_sharding_strategy(
+                 name=name,
+                 sharding_spec_mapping=sharding_spec_mapping,
+                 communication_action_mapping={})
+             strategies_vector.append(sharding_strategy)
+         return strategies_vector
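softmax_node.py pops softmax_dim from every candidate because softmax normalizes over the entire axis: computing it independently on shards of that axis would produce a different result than the unsharded op. A small numpy illustration (not part of the module above):

import numpy as np


def softmax(v, axis):
    e = np.exp(v - v.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


x = np.array([[1.0, 2.0, 3.0, 4.0]])
full = softmax(x, axis=1)
# Pretend the softmax axis were split across two devices of the mesh:
sharded = np.concatenate(
    [softmax(x[:, :2], axis=1), softmax(x[:, 2:], axis=1)], axis=1)
print(np.allclose(full, sharded))  # False -- each shard normalizes only over
                                   # its local slice, so sharding the softmax
                                   # axis would change the computation.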