Add files using upload-large-folder tool
- logs/none_99omtdbz/attempt_0/0/stderr.log +0 -0
- logs/none_99omtdbz/attempt_0/2/stderr.log +0 -0
- logs/none_99omtdbz/attempt_0/5/stdout.log +0 -0
- logs/none_99omtdbz/attempt_0/6/stdout.log +0 -0
- profile_trace/iteration_17920/rank1_trace.json +0 -0
- profile_trace/iteration_17920/rank5_trace.json +0 -0
- profile_trace/iteration_21504/rank2_trace.json +0 -0
- profile_trace/iteration_24576/rank1_trace.json +0 -0
- profile_trace/iteration_24576/rank7_trace.json +0 -0
- profile_trace/iteration_33792/rank1_trace.json +0 -0
- profile_trace/iteration_39936/rank2_trace.json +0 -0
- profile_trace/iteration_39936/rank4_trace.json +0 -0
- profile_trace/iteration_512/rank0_trace.json +0 -0
- pyproject.toml +43 -0
- torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
- torchtitan/components/__pycache__/ft.cpython-312.pyc +0 -0
- torchtitan/components/float8.py +150 -0
- torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc +0 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
- torchtitan/experiments/flux/flux_argparser.py +42 -0
- torchtitan/experiments/flux/loss.py +27 -0
- torchtitan/experiments/flux/requirements.txt +2 -0
- torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
- torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
- torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/model.py +466 -0
- torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py +536 -0
- torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
- torchtitan/experiments/multimodal/mm_dataset.py +268 -0
- torchtitan/experiments/multimodal/requirements.txt +1 -0
- torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
- torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
- torchtitan/experiments/multimodal/utils.py +437 -0
- torchtitan/experiments/simple_fsdp/README.md +40 -0
- torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/tests/test_numerics.py +128 -0
- torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/models/llama3/__pycache__/model.cpython-312.pyc +0 -0
- torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
- torchtitan/models/llama3/parallelize_llama.py +398 -0
- torchtitan/models/llama3/train_configs/llama3_405b.toml +63 -0
- torchtitan/tools/__pycache__/utils.cpython-312.pyc +0 -0
logs/none_99omtdbz/attempt_0/0/stderr.log
ADDED (diff too large to render)
logs/none_99omtdbz/attempt_0/2/stderr.log
ADDED (diff too large to render)
logs/none_99omtdbz/attempt_0/5/stdout.log
ADDED (file without changes)
logs/none_99omtdbz/attempt_0/6/stdout.log
ADDED (file without changes)
profile_trace/iteration_17920/rank1_trace.json
ADDED (diff too large to render)
profile_trace/iteration_17920/rank5_trace.json
ADDED (diff too large to render)
profile_trace/iteration_21504/rank2_trace.json
ADDED (diff too large to render)
profile_trace/iteration_24576/rank1_trace.json
ADDED (diff too large to render)
profile_trace/iteration_24576/rank7_trace.json
ADDED (diff too large to render)
profile_trace/iteration_33792/rank1_trace.json
ADDED (diff too large to render)
profile_trace/iteration_39936/rank2_trace.json
ADDED (diff too large to render)
profile_trace/iteration_39936/rank4_trace.json
ADDED (diff too large to render)
profile_trace/iteration_512/rank0_trace.json
ADDED (diff too large to render)
pyproject.toml
ADDED
@@ -0,0 +1,43 @@
+[project]
+name = "flame"
+dynamic = ["version"]
+description = "A minimal training framework for scaling FLA models"
+readme = "README.md"
+authors = [
+    { name = "Songlin Yang", email = "yangsl66@mit.edu" },
+    { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
+]
+license = { file = "LICENSE" }
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+requires-python = ">=3.10"
+dependencies = [
+    'torch==2.6',
+    'torchdata',
+    'transformers==4.51.3',
+    'triton>=3.0',
+    'datasets>=3.3.0',
+    'einops',
+    'ninja',
+    'wandb',
+    'tiktoken',
+    'tensorboard',
+    'python-dotenv'
+]
+
+[project.optional-dependencies]
+dev = ["pytest"]
+
+[project.urls]
+Homepage = "https://github.com/fla-org/flame"
+
+[build-system]
+requires = ["setuptools>=45", "wheel", "ninja", "torch"]
+
+[tool.isort]
+line_length = 127
+multi_line_output = 3
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc
ADDED (binary file, 33.1 kB)
torchtitan/components/__pycache__/dataloader.cpython-312.pyc
ADDED (binary file, 3.78 kB)
torchtitan/components/__pycache__/ft.cpython-312.pyc
ADDED (binary file, 6.75 kB)
torchtitan/components/float8.py
ADDED
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
+# Please ensure you have this package installed from the appropriate repository.
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.
+
+# Note: Performance
+# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
+
+import torch
+import torch.nn as nn
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.model_converter import (
+    ModelConverter,
+    register_model_converter,
+)
+from torchtitan.tools.logging import logger
+
+
+def _is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+class Float8Converter(ModelConverter):
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.enabled = False
+
+        float8_config = job_config.float8
+        if not _is_sm89_or_later():
+            logger.warning(
+                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+            )
+            return
+        try:
+            from torchao.float8 import Float8LinearConfig
+        except ImportError as e:
+            raise ImportError(
+                "torchao is not installed. Please install it to use float8 linear layers."
+            ) from e
+
+        if float8_config.recipe_name is not None and not hasattr(
+            Float8LinearConfig, "from_recipe_name"
+        ):
+            logger.warning(
+                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
+                "is too old, please install torchao v0.9.0 or later and try again",
+            )
+            return
+
+        self.enabled = True
+        self.filter_fqns = float8_config.filter_fqns
+
+        if float8_config.recipe_name is not None:
+            assert (
+                not float8_config.enable_fsdp_float8_all_gather
+            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
+            assert (
+                not float8_config.force_recompute_fp8_weight_in_bwd
+            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
+            self.precompute_scale = False
+            logger.info(
+                f"Float8 training active with recipe {float8_config.recipe_name}"
+            )
+
+        else:
+            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            enable_fsdp_float8_all_gather = (
+                parallel_dims.dp_shard_enabled
+                and float8_config.enable_fsdp_float8_all_gather
+            )
+            self.config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
+            )
+            # for precompute_float8_dynamic_scale_for_fsdp
+            self.precompute_scale = (
+                enable_fsdp_float8_all_gather
+                and float8_config.precompute_float8_dynamic_scale_for_fsdp
+            )
+            logger.info("Float8 tensorwise scaled training active")
+
+    def convert(self, model: nn.Module):
+        return self.convert_to_float8_training(model)
+
+    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
+        return self.precompute_float8_dynamic_scale_for_fsdp(model)
+
+    def convert_to_float8_training(self, model: nn.Module):
+        """
+        This function converts the linear layers of `model` to `Float8Linear`.
+        Note that today, only dynamic tensor scaling (the default) is supported.
+        This will mutate the model inplace.
+        """
+        if not self.enabled:
+            return
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+        convert_to_float8_training(
+            model,
+            config=self.config,
+            module_filter_fn=self._module_filter_fn,
+        )
+        logger.info(
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
+        )
+
+    def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
+        if not isinstance(mod, nn.Linear):
+            return False
+
+        # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
+        dims_multiples_of_16 = (
+            mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
+        )
+
+        # If the fqn matches any filtered fqn, then we should not convert this module.
+        is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)
+
+        return dims_multiples_of_16 and not is_filtered_fqn
+
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: nn.Module | list[nn.Module]
+    ):
+        if not self.enabled:
+            return
+
+        if not self.precompute_scale:
+            return
+
+        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
+
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)
+
+
+register_model_converter(Float8Converter, "float8")
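For reference, `Float8Converter._module_filter_fn` above only swaps `nn.Linear` layers whose weight dimensions are multiples of 16 and whose fully qualified name is not filtered. Below is a minimal standalone sketch of that selection logic; the toy model and the `filter_fqns` list are made up for illustration, and it does not require torchao.

```python
import torch.nn as nn

def eligible_for_float8(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    # Mirrors the filter above: only nn.Linear layers qualify.
    if not isinstance(mod, nn.Linear):
        return False
    # Both weight dims must be divisible by 16 (float8 tensorcore requirement).
    dims_ok = mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
    # Skip modules whose name matches any filtered fqn (e.g. the output head).
    filtered = any(f in fqn for f in filter_fqns)
    return dims_ok and not filtered

# Toy model for illustration only.
model = nn.ModuleDict({
    "attn_proj": nn.Linear(4096, 4096),   # eligible: both dims are multiples of 16
    "small_head": nn.Linear(4096, 1000),  # rejected: 1000 is not a multiple of 16
    "output": nn.Linear(4096, 4096),      # rejected: name is in filter_fqns
})
for name, mod in model.items():
    print(name, eligible_for_float8(mod, name, filter_fqns=["output"]))
```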
torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc
ADDED (binary file, 6.1 kB)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py
ADDED
@@ -0,0 +1,159 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import triton
+import triton.language as tl
+
+from .triton_utils import get_flat_bid, get_flat_tid
+
+
+@triton.jit
+def send_signal(addrs, sem: tl.constexpr):
+    if sem == "relaxed":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32 %tmp32_<1>;
+                .reg .pred %p<1>;
+
+                send_signal:
+                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                    setp.eq.u32 %p0, %tmp32_0, 0;
+                    @!%p0 bra send_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    elif sem == "acq_rel":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32 %tmp32_<1>;
+                .reg .pred %p<1>;
+
+                send_signal:
+                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                    setp.eq.u32 %p0, %tmp32_0, 0;
+                    @!%p0 bra send_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+@triton.jit
+def wait_signal(addrs, sem: tl.constexpr):
+    if sem == "relaxed":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32 %tmp32_<1>;
+                .reg .pred %p<1>;
+
+                wait_signal:
+                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
+                    setp.eq.u32 %p0, %tmp32_0, 1;
+                    @!%p0 bra wait_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    elif sem == "acq_rel":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32 %tmp32_<1>;
+                .reg .pred %p<1>;
+
+                wait_signal:
+                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
+                    setp.eq.u32 %p0, %tmp32_0, 1;
+                    @!%p0 bra wait_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+@triton.jit
+def blockwise_barrier(
+    signal_pad_ptrs,
+    block_id,
+    rank: tl.constexpr,
+    world_size: tl.constexpr,
+    sem: tl.constexpr,
+):
+    """
+    Synchronizes blocks with matching block_id across participating devices.
+
+    Note: the function itself is not a system level barrier/fence. It is a
+    building block for expressing different synchronization patterns.
+
+    Pattern 0: Ensures that all writes to symm_mem buffers from previous
+    kernels across all devices are visible to the current kernel:
+
+        blockwise_barrier(..., sem="relaxed")
+        sync_threads()
+
+    Pattern 1: Ensures that all writes to symm_mem buffers from the current
+    block are visible to all remote blocks with matching blockIdx:
+
+        sync_threads()
+        blockwise_barrier(..., sem="acq_rel")
+        sync_threads()
+
+    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
+    for writing by subsequent kernels across all devices.
+
+        sync_threads()
+        blockwise_barrier(..., sem="relaxed")
+
+    CUDA graph friendliness:
+
+        This barrier operates through atomic operations on a zero-filled signal
+        pad, which resets to a zero-filled state after each successful
+        synchronization. This design eliminates the need for incrementing a
+        flag from host.
+    """
+    if block_id is None:
+        block_id = get_flat_bid()
+    flat_tid = get_flat_tid()
+
+    remote_ranks = tl.arange(0, world_size)
+    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
+    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
+        tl.pointer_type(tl.uint32)
+    )
+    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
+
+    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
+        tl.pointer_type(tl.uint32)
+    )
+    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
+
+    if flat_tid < world_size:
+        send_signal(send_addrs, sem)
+        wait_signal(wait_addrs, sem)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+import triton
+import triton.language as tl
+
+from .triton_barrier import blockwise_barrier
+from .triton_utils import sync_threads
+
+
+@triton.jit
+def _exchange_row_offsets(
+    split_sizes_ptrs,
+    rank: tl.constexpr,
+    world_size: tl.constexpr,
+    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
+):
+    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
+
+    # split_sizes_ptr for all ranks
+    # All these vector stacks into split_sizes_matrix
+    split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
+
+    # split_sizes_matrix[remote_rank, :]
+    input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
+        tl.pointer_type(tl.int64)
+    )
+
+    offsets_ = tl.arange(0, world_size)
+    input_split_sizes = tl.load(
+        input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
+    )
+
+    num_rows = tl.load(input_split_sizes_ptr + rank)
+    input_row_offset = tl.sum(input_split_sizes) - num_rows
+
+    # split_sizes_matrix[:, rank]
+    output_split_sizes_ptrs = (
+        tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
+    )
+    output_split_sizes = tl.load(
+        output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
+    )
+    output_row_offset = tl.sum(output_split_sizes) - num_rows
+
+    return input_row_offset, output_row_offset, num_rows
+
+
+@triton.jit
+def on_device_all_to_all_v_kernel(
+    output_ptr,
+    output_splits_ptr,
+    input_ptrs,
+    input_splits_ptr,
+    signal_pad_ptrs,
+    dim: tl.constexpr,  # Separate dim for easier vectorization
+    rank: tl.constexpr,
+    world_size: tl.constexpr,
+    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
+    UNROLL_FACTOR: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
+    sync_threads()
+
+    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
+    block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
+
+    input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
+        input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
+    )
+
+    output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
+    if block_offset == 0:
+        # Update output_splits
+        tl.store(output_splits_ptr + remote_rank, num_rows)
+
+    input_ptr = (
+        tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
+            tl.pointer_type(tl.bfloat16)
+        )
+        + input_row_offset * dim
+    )
+    output_ptr = output_ptr + output_row_offset * dim
+
+    outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
+    outer_loop_iters_per_rank = tl.cdiv(
+        tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
+    )
+    numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
+    offset = numel_per_rank * block_offset
+    end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
+
+    unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
+    for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
+        datas = []
+        for j in tl.range(
+            i,
+            i + outer_loop_step,
+            BLOCK_SIZE,
+            loop_unroll_factor=UNROLL_FACTOR,
+        ):
+            offsets = j + tl.arange(0, BLOCK_SIZE)
+            data = tl.load(input_ptr + offsets)
+            tl.store(output_ptr + offsets, data)
+
+    offset += unroll_region_size
+    while offset < end:
+        offsets = offset + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < num_rows * dim
+        data = tl.load(input_ptr + offsets, mask=mask)
+        tl.store(output_ptr + offsets, data, mask=mask)
+        offset += BLOCK_SIZE
+
+    sync_threads()
+    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
+    return
+
+
+def _on_device_all_to_all_v(
+    output: torch.Tensor,
+    output_splits: torch.Tensor,
+    input: torch.Tensor,
+    input_splits: torch.Tensor,
+    group: dist.ProcessGroup = dist.group.WORLD,
+    BLOCKS_PER_REMOTE_RANK=8,
+    UNROLL_FACTOR: int = 8,
+    BLOCK_SIZE: int = 16384,
+):
+    assert output.dim() == 2, f"{output.shape}"
+    assert input.dim() == 2, f"{input.shape}"
+    assert output.shape[1] == input.shape[1]
+
+    dim = output.shape[1]
+    input_hdl = symm_mem.rendezvous(input, group=group)
+    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
+
+    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
+    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
+        output,
+        output_splits,
+        input_hdl.buffer_ptrs_dev,
+        input_splits_hdl.buffer_ptrs_dev,
+        input_hdl.signal_pad_ptrs_dev,
+        dim=dim,
+        rank=input_hdl.rank,
+        world_size=input_hdl.world_size,
+        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
+        UNROLL_FACTOR=UNROLL_FACTOR,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=16,
+    )
+    # log_triton_kernel(kernel)
+    return output
+
+
+class OnDeviceAllToAllV(torch.autograd.Function):
+    # A symmetric memory holding the grad_output during backward
+    grad_output_buf = None
+    # A symmetric memory for exchanges split sizes during both forward and backward
+    splits_buf = None
+    # Maximum output length (need to be set before use of OnDeviceAllToAllV)
+    max_output_len = None
+
+    @staticmethod
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        input_splits: torch.Tensor,
+        group: dist.ProcessGroup = dist.group.WORLD,
+    ):
+        """
+        Args:
+            input: input tensor with data for all ranks concatenated.
+            input_splits: input splits of shape (group.world_size,)
+            group: process group to scope the collective.
+        """
+        # Initialize input splits buffer (one time only)
+        if OnDeviceAllToAllV.splits_buf is None:
+            OnDeviceAllToAllV.splits_buf = symm_mem.empty(
+                *input_splits.shape,
+                dtype=input_splits.dtype,
+                device=input_splits.device,
+            )
+
+        if OnDeviceAllToAllV.max_output_len is None:
+            raise RuntimeError(
+                "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
+            )
+
+        # Allocate output buffer
+        output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
+        # Allocate output splits tensor
+        output_splits = torch.empty_like(input_splits)
+        # Copy input splits to the buffer
+        OnDeviceAllToAllV.splits_buf.copy_(input_splits)
+
+        # Shuffle input to output
+        _on_device_all_to_all_v(
+            output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
+        )
+
+        # Output splits in forward is the input splits in backward
+        ctx.save_for_backward(output_splits)
+        ctx.group = group
+        ctx.input_shape = input.shape
+        return output, output_splits
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_splits):
+        """
+        Backward is implemented as a shuffle of the output's gradients to the input.
+        Args:
+            `grad_output`: output's gradients passed from the downstream.
+            `grad_splits`: unused.
+        """
+
+        # Initialize grad_output buffer (one time only)
+        if OnDeviceAllToAllV.grad_output_buf is None:
+            assert (
+                OnDeviceAllToAllV.max_output_len is not None
+            ), "`max_output_len` not set"
+            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
+                OnDeviceAllToAllV.max_output_len,
+                *grad_output.shape[1:],
+                dtype=grad_output.dtype,
+                device=grad_output.device,
+            )
+
+        # TODO: is there a way to tell autograd to feed grad_output directly to
+        # our symm_mem buffer?
+        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
+            grad_output
+        )
+
+        # Size info
+        (grad_output_splits,) = ctx.saved_tensors
+        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
+        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
+        grad_input = grad_output.new_empty(*ctx.input_shape)
+
+        # Shuffle gradients back to the input
+        _on_device_all_to_all_v(
+            grad_input,
+            grad_input_splits,
+            OnDeviceAllToAllV.grad_output_buf,
+            OnDeviceAllToAllV.splits_buf,
+            group=ctx.group,
+        )
+        return grad_input, None, None
+
+
+# Alias
+on_device_all_to_all_v = OnDeviceAllToAllV.apply
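In the kernel above, `_exchange_row_offsets` derives, for the block handling `remote_rank`, where its rows start in the remote input (prefix sum of `split_sizes_matrix[remote_rank, :rank]`) and where they land in the local output (prefix sum of `split_sizes_matrix[:remote_rank, rank]`). A minimal host-side PyTorch sketch of the same arithmetic on a full split-size matrix; the 3-rank example values are made up for illustration:

```python
import torch

def exchange_row_offsets(split_sizes: torch.Tensor, rank: int, remote_rank: int):
    # split_sizes[i, j] = number of rows rank i sends to rank j.
    # Rows destined for `rank` start after everything `remote_rank` sends to
    # lower-numbered ranks.
    input_row_offset = int(split_sizes[remote_rank, :rank].sum())
    # In `rank`'s output, data from `remote_rank` lands after the data received
    # from all lower-numbered source ranks.
    output_row_offset = int(split_sizes[:remote_rank, rank].sum())
    num_rows = int(split_sizes[remote_rank, rank])
    return input_row_offset, output_row_offset, num_rows

# Illustrative 3-rank split matrix.
split_sizes = torch.tensor([[2, 1, 3],
                            [4, 0, 2],
                            [1, 5, 1]])
print(exchange_row_offsets(split_sizes, rank=2, remote_rank=1))  # (4, 3, 2)
```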
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def get_tid():
+    return tl.inline_asm_elementwise(
+        """
+        mov.u32 $0, %tid.x;
+        mov.u32 $1, %tid.y;
+        mov.u32 $2, %tid.z;
+        """,
+        "=r,=r,=r",
+        [],
+        dtype=(tl.uint32, tl.uint32, tl.uint32),
+        is_pure=True,
+        pack=1,
+    )
+
+
+@triton.jit
+def get_ntid():
+    return tl.inline_asm_elementwise(
+        """
+        mov.u32 $0, %ntid.x;
+        mov.u32 $1, %ntid.y;
+        mov.u32 $2, %ntid.z;
+        """,
+        "=r,=r,=r",
+        [],
+        dtype=(tl.uint32, tl.uint32, tl.uint32),
+        is_pure=True,
+        pack=1,
+    )
+
+
+@triton.jit
+def get_flat_tid():
+    tid_x, tid_y, tid_z = get_tid()
+    ntid_x, ntid_y, _ = get_ntid()
+    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
+
+
+@triton.jit
+def get_flat_bid():
+    return (
+        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
+        + tl.program_id(1) * tl.num_programs(0)
+        + tl.program_id(0)
+    )
+
+
+@triton.jit
+def sync_threads():
+    tl.inline_asm_elementwise(
+        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+    )
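`get_flat_tid` and `get_flat_bid` above linearize 3D thread and block coordinates in x-fastest order. A small host-side sketch of the same flattening, with example launch dimensions chosen arbitrarily:

```python
def flatten_3d(x: int, y: int, z: int, nx: int, ny: int) -> int:
    # x varies fastest, then y, then z: same ordering as get_flat_tid / get_flat_bid.
    return z * ny * nx + y * nx + x

# For a (4, 2, 2) launch: coordinate (1, 1, 0) flattens to 5; (0, 0, 1) flattens to 8.
print(flatten_3d(1, 1, 0, nx=4, ny=2))  # 5
print(flatten_3d(0, 0, 1, nx=4, ny=2))  # 8
```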
torchtitan/experiments/flux/flux_argparser.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+
+
+def extend_parser(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument(
+        "--training.guidance",
+        type=float,
+        default=3.5,
+        help="guidance value used for guidance distillation",
+    )
+    parser.add_argument(
+        "--encoder.t5_encoder",
+        type=str,
+        default="google/t5-v1_1-small",
+        help="T5 encoder to use, HuggingFace model name.",
+    )
+    parser.add_argument(
+        "--encoder.clip_encoder",
+        type=str,
+        default="openai/clip-vit-large-patch14",
+        help="Clip encoder to use, HuggingFace model name.",
+    )
+    parser.add_argument(
+        "--encoder.encoder_dtype",
+        type=torch.dtype,
+        default=torch.bfloat16,
+        help="Which dtype to load for autoencoder. ",
+    )
+    parser.add_argument(
+        "--encoder.max_t5_encoding_len",
+        type=int,
+        default=512,
+        help="Maximum length of the T5 encoding.",
+    )
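The extra flags above use dotted option names such as `--training.guidance`, which argparse stores under the dotted string itself rather than a plain attribute name. A standalone sketch of how such options parse; the harness below is only an illustration and is not torchtitan's JobConfig:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--training.guidance", type=float, default=3.5)
parser.add_argument("--encoder.max_t5_encoding_len", type=int, default=512)

args = parser.parse_args(["--training.guidance", "4.0"])
# Dotted names are not valid Python identifiers, so read them via vars() or getattr.
print(vars(args)["training.guidance"])                # 4.0
print(getattr(args, "encoder.max_t5_encoding_len"))   # 512
```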
torchtitan/experiments/flux/loss.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, TypeAlias
+
+import torch
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
+
+LossFunction: TypeAlias = Callable[..., torch.Tensor]
+
+
+def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """Common MSE loss function for Transformer models training."""
+    return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())
+
+
+def build_mse_loss(job_config: JobConfig):
+    loss_fn = mse_loss
+    if job_config.training.compile:
+        logger.info("Compiling the loss function with torch.compile")
+        loss_fn = torch.compile(loss_fn)
+    return loss_fn
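`mse_loss` upcasts both tensors to float32 and detaches the labels before computing the mean squared error; `build_mse_loss` optionally wraps it in `torch.compile`. A minimal CPU-only usage sketch with arbitrary shapes, independent of the torchtitan config objects:

```python
import torch
import torch.nn.functional as F

def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same as the loss above: compute in float32, no gradient through labels.
    return F.mse_loss(pred.float(), labels.float().detach())

pred = torch.randn(4, 3, 256, 256, requires_grad=True)
labels = torch.randn(4, 3, 256, 256)
loss = mse_loss(pred, labels)
loss.backward()
print(loss.item(), pred.grad.shape)
```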
torchtitan/experiments/flux/requirements.txt
ADDED
@@ -0,0 +1,2 @@
+transformers
+einops
torchtitan/experiments/flux/tests/test_flux_dataloader.py
ADDED
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
+from torchtitan.tools.profiling import (
+    maybe_enable_memory_snapshot,
+    maybe_enable_profiling,
+)
+
+
+class TestFluxDataLoader:
+    def test_flux_dataloader(self):
+        dataset_name = "cc12m"
+        batch_size = 32
+        world_size = 4
+        rank = 0
+
+        num_steps = 10
+
+        path = "torchtitan.experiments.flux.flux_argparser"
+        sys.argv.append(f"--experimental.custom_args_module={path}")
+        config = JobConfig()
+        config.maybe_add_custom_args()
+        config.parse_args(
+            [
+                # Profiling options
+                # "--profiling.enable_profiling",
+                # "--profiling.profile_freq",
+                # "5",
+                # "--profiling.enable_memory_snapshot",
+                # "--profiling.save_memory_snapshot_folder",
+                # "memory_snapshot_flux",
+                "--training.dataset",
+                dataset_name,
+                "--training.batch_size",
+                str(batch_size),
+                "--encoder.t5_encoder",
+                "google/t5-v1_1-small",
+                "--encoder.clip_encoder",
+                "openai/clip-vit-large-patch14",
+                "--encoder.max_t5_encoding_len",
+                "512",
+            ]
+        )
+
+        with maybe_enable_profiling(
+            config, global_step=0
+        ) as torch_profiler, maybe_enable_memory_snapshot(
+            config, global_step=0
+        ) as memory_profiler:
+            dl = self._build_dataloader(
+                config,
+                world_size,
+                rank,
+            )
+            dl = iter(dl)
+
+            for i in range(0, num_steps):
+                input_data, labels = next(dl)
+                print(f"Step {i} image size: {labels.shape}")
+                if torch_profiler:
+                    torch_profiler.step()
+                if memory_profiler:
+                    memory_profiler.step()
+
+                print(len(input_data["clip_tokens"]))
+                for k, v in input_data.items():
+                    print(f"Step {i} {k} value: {type(v), v.shape}")
+
+                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
+                assert labels.shape == (batch_size, 3, 256, 256)
+                # assert input_data["clip_tokens"].shape[0] == batch_size
+                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
+
+            if torch_profiler:
+                torch_profiler.step()
+            if memory_profiler:
+                memory_profiler.step(exit_ctx=True)
+
+    def test_preprocess(self):
+        # TODO
+        pass
+
+    def _build_dataloader(
+        self,
+        job_config,
+        world_size,
+        rank,
+    ):
+
+        return build_flux_dataloader(
+            dp_world_size=world_size,
+            dp_rank=rank,
+            job_config=job_config,
+            tokenizer=None,
+            infinite=False,
+        )
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import os
+import time
+from typing import Callable
+
+import torch
+from einops import rearrange
+
+from PIL import ExifTags, Image
+
+from torch import Tensor
+
+from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
+
+from torchtitan.experiments.flux.model.autoencoder import (
+    AutoEncoder,
+    AutoEncoderParams,
+    load_ae,
+)
+from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
+
+from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
+from torchtitan.experiments.flux.utils import (
+    create_position_encoding_for_latents,
+    generate_noise_latent,
+    pack_latents,
+    preprocess_flux_data,
+    unpack_latents,
+)
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> list[float]:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # estimate mu based on linear estimation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps.tolist()
+
+
+class TestGenerateImage:
+    def test_generate_image(self):
+        """
+        Run a forward pass of flux model to generate an image.
+        """
+        name = "flux-dev"
+        img_width = 512
+        img_height = 512
+        seed = None
+        prompt = (
+            "a photo of a forest with mist swirling around the tree trunks. The word "
+            '"FLUX" is painted over it in big, red brush strokes with visible texture'
+        )
+        device = "cuda"
+        num_steps = None
+        loop = False
+        guidance = 3.5
+        output_dir = "output"
+        add_sampling_metadata = True
+
+        prompt = prompt.split("|")
+        if len(prompt) == 1:
+            prompt = prompt[0]
+            additional_prompts = None
+        else:
+            additional_prompts = prompt[1:]
+            prompt = prompt[0]
+
+        assert not (
+            (additional_prompts is not None) and loop
+        ), "Do not provide additional prompts and set loop to True"
+
+        torch_device = torch.device(device)
+        if num_steps is None:
+            num_steps = 30
+
+        # allow for packing and conversion to latent space
+        img_height = 16 * (img_height // 16)
+        img_width = 16 * (img_width // 16)
+
+        # init all components
+        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
+
+        ae = load_ae(
+            ckpt_path="assets/autoencoder/ae.safetensors",
+            autoencoder_params=AutoEncoderParams(),
+            device=torch_device,
+            dtype=torch.bfloat16,
+        )
+        clip_tokenizer = FluxTokenizer(
+            model_path="openai/clip-vit-large-patch14", max_length=77
+        )
+        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
+        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
+            torch_device, dtype=torch.bfloat16
+        )
+        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
+            torch_device, dtype=torch.bfloat16
+        )
+
+        rng = torch.Generator(device="cpu")
+
+        if seed is None:
+            seed = rng.seed()
+        print(f"Generating with seed {seed}:\n{prompt}")
+        t0 = time.perf_counter()
+        output_name = os.path.join(output_dir, f"img_{seed}.jpg")
+
+        # Tokenize the prompt, on CPU
+        clip_tokens = clip_tokenizer.encode(prompt)
+        t5_tokens = t5_tokenizer.encode(prompt)
+
+        batch = preprocess_flux_data(
+            device=torch_device,
+            dtype=torch.bfloat16,
+            autoencoder=None,
+            clip_encoder=clip_encoder,
+            t5_encoder=t5_encoder,
+            batch={
+                "clip_tokens": clip_tokens,
+                "t5_tokens": t5_tokens,
+            },
+        )
+
+        img = self._generate_images(
+            device=torch_device,
+            dtype=torch.bfloat16,
+            model=model,
+            decoder=ae,
+            img_width=img_width,
+            img_height=img_height,
+            denoising_steps=num_steps,
+            seed=seed,
+            clip_encodings=batch["clip_encodings"],
+            t5_encodings=batch["t5_encodings"],
+            guidance=guidance,
+        )
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        t1 = time.perf_counter()
+
+        print(f"Done in {t1 - t0:.1f}s.")
+
+        self._save_image(name, output_name, img, add_sampling_metadata, prompt)
+
+    def _generate_images(
+        self,
+        device: torch.device,
+        dtype: torch.dtype,
+        model: FluxModel,
+        decoder: AutoEncoder,
+        # image params:
+        img_width: int,
+        img_height: int,
+        # sampling params:
+        denoising_steps: int,
+        seed: int,
+        clip_encodings: torch.Tensor,
+        t5_encodings: torch.Tensor,
+        guidance: float = 4.0,
+    ):
+
+        bsz = clip_encodings.shape[0]
+        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
+        _, latent_channels, latent_height, latent_width = latents.shape
+
+        # create denoising schedule
+        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
+
+        # create positional encodings
+        POSITION_DIM = 3  # constant for Flux flow model
+        latent_pos_enc = create_position_encoding_for_latents(
+            bsz, latent_height, latent_width, POSITION_DIM
+        ).to(latents)
+        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
+
+        # convert img-like latents into sequences of patches
+        latents = pack_latents(latents)
+
+        # this is ignored for schnell
+        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
+        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
+            pred = model(
+                img=latents,
+                img_ids=latent_pos_enc,
+                txt=t5_encodings,
+                txt_ids=text_pos_enc,
+                y=clip_encodings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+            )
+
+            latents = latents + (t_prev - t_curr) * pred
+
+        # convert sequences of patches into img-like latents
+        latents = unpack_latents(latents, latent_height, latent_width)
+
+        img = decoder.decode(latents)
+        return img
+
+    def _save_image(
+        self,
+        name: str,
+        output_name: str,
+        x: torch.Tensor,
+        add_sampling_metadata: bool,
+        prompt: str,
+    ):
+        print(f"Saving {output_name}")
+        # bring into PIL format and save
+        x = x.clamp(-1, 1)
+        x = rearrange(x[0], "c h w -> h w c")
+
+        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+
+        exif_data = Image.Exif()
+        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+        exif_data[ExifTags.Base.Model] = name
+        if add_sampling_metadata:
+            exif_data[ExifTags.Base.ImageDescription] = prompt
+        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
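`get_schedule` in the test above builds `num_steps + 1` timesteps from 1 to 0 and, when `shift=True`, warps them with `time_shift` using a `mu` interpolated linearly in the image sequence length (0.5 at 256 tokens, 1.15 at 4096). A standalone sketch using the same formulas; the step count and sequence length below are just an example run:

```python
import math
import torch

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def shifted_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    # Linear interpolation of mu between (256, 0.5) and (4096, 1.15).
    m = (1.15 - 0.5) / (4096 - 256)
    mu = m * image_seq_len + (0.5 - m * 256)
    timesteps = torch.linspace(1, 0, num_steps + 1)  # extra step for zero
    return time_shift(mu, 1.0, timesteps).tolist()

# Longer sequences get a larger mu, pushing more steps toward high noise levels.
print(shifted_schedule(num_steps=4, image_seq_len=1024))
```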
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
@@ -0,0 +1,68 @@
+
+[job]
+dump_folder = "./outputs"
+description = "Flux debug model"
+print_args = false
+use_for_integration_test = true
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "flux"
+flavor = "flux-debug"
+norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
+# test tokenizer.model, for debug purpose only
+# tokenizer_path = "./tests/assets/test_tiktoken.model"
+# converters = "float8"
+
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+lr_min = 0.0
+
+[training]
+batch_size = 32
+seq_len = 512
+max_norm = 1.0  # grad norm clipping
+steps = 10
+compile = false
+dataset = "cc12m"
+guidance = 3.5
+seed = 0
+
+[encoder]
+t5_encoder = "google/t5-v1_1-small"
+clip_encoder = "openai/clip-vit-large-patch14"
+max_t5_encoding_len = 512
+auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = 1
+fsdp_reshard_after_forward = "default"  # default / never / always
+tensor_parallel_degree = 1
+enable_async_tensor_parallel = false
+pipeline_parallel_degree = 1
+context_parallel_degree = 1
+
+[experimental]
+custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
8 |
+
# All rights reserved.
|
9 |
+
#
|
10 |
+
# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
|
11 |
+
|
12 |
+
import argparse
|
13 |
+
import logging
|
14 |
+
import time
|
15 |
+
|
16 |
+
# from typing import Dict, List, Optional, Tuple
|
17 |
+
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import triton
|
22 |
+
|
23 |
+
# import triton.language as tl
|
24 |
+
|
25 |
+
# Configure logging
|
26 |
+
logging.basicConfig(
|
27 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
28 |
+
)
|
29 |
+
|
30 |
+
# Try to import the optimized implementations
|
31 |
+
try:
|
32 |
+
from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
|
33 |
+
|
34 |
+
except ImportError:
|
35 |
+
logging.error(
|
36 |
+
"Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
|
37 |
+
)
|
38 |
+
raise
|
39 |
+
|
40 |
+
|
41 |
+
def compute_reference_forward(x, w, m_sizes):
|
42 |
+
"""
|
43 |
+
Reference PyTorch implementation of M*G grouped GEMM forward pass.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x (torch.Tensor): Input tensor of shape (M, K)
|
47 |
+
w (torch.Tensor): Weight tensor of shape (N, K)
|
48 |
+
m_sizes (torch.Tensor): Group sizes tensor of shape (G)
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
torch.Tensor: Output tensor of shape (M, N)
|
52 |
+
"""
|
53 |
+
result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
|
54 |
+
|
55 |
+
m_start = 0
|
56 |
+
for g in range(len(m_sizes)):
|
57 |
+
m_size = m_sizes[g].item()
|
58 |
+
if m_size > 0:
|
59 |
+
m_end = m_start + m_size
|
60 |
+
|
61 |
+
# Extract group input
|
62 |
+
x_g = x[m_start:m_end]
|
63 |
+
|
64 |
+
# Compute group output
|
65 |
+
y_g = torch.matmul(x_g, w.T)
|
66 |
+
|
67 |
+
# Store result
|
68 |
+
result[m_start:m_end] = y_g
|
69 |
+
|
70 |
+
# Update start index
|
71 |
+
m_start = m_end
|
72 |
+
|
73 |
+
return result
|
74 |
+
|
75 |
+
|
76 |
+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],  # We'll vary the output dimension
        x_vals=[1024, 2048, 4096, 8192, 16384],  # Different output dimensions to test
        # x_vals=[8192, 16384],
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "M*G grouped GEMM"],
        line_names=["PyTorch Reference", "M*G grouped Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_comparison",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 7168,  # Hidden dimension, fixed for all tests
            "G": 8,  # Number of groups
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
    """
    Benchmark the forward pass of the grouped GEMM implementation.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        provider (str): Provider to use ('pytorch_reference' or 'M*G grouped GEMM')
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    # Create group sizes for M dimension (balanced across groups)
    base_size = M // G
    remainder = M % G
    M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
    m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

    print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Pre-compute for PyTorch reference to ensure fair comparison
    if provider == "pytorch_reference":
        # Warmup
        torch.cuda.synchronize()
        compute_reference_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        # Benchmark
        start_time = time.time()
        for _ in range(10):  # Average over 10 runs
            compute_reference_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()
    else:  # Optimized kernel
        # Warmup
        torch.cuda.synchronize()
        grouped_gemm_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        # Benchmark
        start_time = time.time()
        for _ in range(10):  # Average over 10 runs
            grouped_gemm_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()

    # Calculate FLOPs
    # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
    flops = 2 * M * N * K

    # Convert to TFLOPS (tera-FLOPS)
    avg_time = (end_time - start_time) / 10  # Average time per run
    tflops = flops / avg_time / 1e12

    return tflops

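# Editor's note: an alternative timing sketch (an assumption, not something the
# original benchmark uses). time.time() plus manual synchronize works, but Triton
# ships a helper that handles warmup and repetition for CUDA kernels; a drop-in
# wrapper could look like the following.
def _bench_ms(fn, warmup=10, rep=50):
    # triton.testing.do_bench runs fn repeatedly and returns a runtime in milliseconds
    return triton.testing.do_bench(fn, warmup=warmup, rep=rep)
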
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["G"],  # We'll vary the number of groups
        x_vals=[1, 2, 4, 8, 16],  # Different numbers of groups to test
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "optimized_kernel"],
        line_names=["PyTorch Reference", "Optimized Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_group_scaling",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 4096,  # Hidden dimension, fixed for all tests
            "N": 8192,  # Output dimension, fixed for all tests
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
    """
    Benchmark how performance scales with number of groups.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    # Create group sizes for M dimension (balanced across groups)
    base_size = M // G
    remainder = M % G
    M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
    m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Benchmark logic - same as previous function
    if provider == "pytorch_reference":
        torch.cuda.synchronize()
        compute_reference_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(10):
            compute_reference_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()
    else:
        torch.cuda.synchronize()
        grouped_gemm_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(10):
            grouped_gemm_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()

    # Calculate FLOPs and TFLOPS
    flops = 2 * M * N * K
    avg_time = (end_time - start_time) / 10
    tflops = flops / avg_time / 1e12

    return tflops

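# Editor's note (illustrative observation, not in the original file): the FLOP count
# used above, 2 * M * N * K, is independent of G, since every one of the M rows is
# multiplied by the same [N, K] weight. Any TFLOPS difference across group counts in
# this sweep therefore reflects scheduling and launch overhead rather than a change
# in arithmetic work.
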
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["group_balance"],  # We'll vary the group balance factor
        x_vals=[
            0.0,
            0.25,
            0.5,
            0.75,
            0.9,
        ],  # Different imbalance factors (0 = balanced, 1 = max imbalance)
        line_arg="provider",  # We'll compare different providers
        line_vals=["pytorch_reference", "optimized_kernel"],
        line_names=["PyTorch Reference", "Optimized Kernel"],
        styles=[("blue", "-"), ("red", "-")],
        ylabel="TFLOPS",  # We'll measure TFLOPS
        plot_name="mg_grouped_gemm_imbalance",
        args={
            "M": 8192,  # Batch dimension, fixed for all tests
            "K": 4096,  # Hidden dimension, fixed for all tests
            "N": 8192,  # Output dimension, fixed for all tests
            "G": 4,  # Number of groups
            "dtype": torch.float16,
            "device": "cuda",
        },
    )
)
def benchmark_imbalance(
    M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
):
    """
    Benchmark how performance is affected by imbalanced group sizes.

    Args:
        M (int): Total batch size dimension
        K (int): Hidden dimension
        N (int): Output dimension
        G (int): Number of groups
        group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
        provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
        dtype (torch.dtype): Data type to use
        device (str): Device to use

    Returns:
        float: Performance in TFLOPS
    """
    # Create imbalanced group sizes for M dimension
    if group_balance == 0:
        # Balanced case
        base_size = M // G
        remainder = M % G
        M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
    else:
        # Imbalanced case
        # First group gets more elements, last group gets fewer
        # The imbalance is controlled by the group_balance factor
        remaining = M
        M_sizes = []
        for g in range(G):
            # Interpolate from balanced to imbalanced based on group_balance
            # For balanced (group_balance=0), each group gets M/G
            # For imbalanced (group_balance=1), first group gets much more than last group
            balanced_size = remaining // (G - g)

            # Adjusting size based on position and imbalance factor
            # First groups get more, last groups get less
            if g < G // 2:
                # First half of groups get more
                adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
                size = balanced_size + adjustment
            else:
                # Second half of groups get less
                adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
                size = balanced_size - adjustment

            # Ensure we don't go below 1 or take more than remaining
            size = max(1, min(size, remaining))
            M_sizes.append(size)
            remaining -= size

        # Handle any remaining elements
        if remaining > 0:
            M_sizes[-1] += remaining

    m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

    # Create input and weight tensors
    x = torch.randn(M, K, dtype=dtype, device=device)
    w = torch.randn(N, K, dtype=dtype, device=device)

    # Benchmark logic
    if provider == "pytorch_reference":
        torch.cuda.synchronize()
        compute_reference_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(10):
            compute_reference_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()
    else:
        torch.cuda.synchronize()
        grouped_gemm_forward(x, w, m_sizes)
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(10):
            grouped_gemm_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()

    # Calculate FLOPs and TFLOPS
    flops = 2 * M * N * K
    avg_time = (end_time - start_time) / 10
    tflops = flops / avg_time / 1e12

    return tflops

def benchmark_model_configs():
    """
    Benchmark common model configurations used in DeepSeek-like models.
    """
    # Model configurations: (M, K, N, G)
    configs = [
        (8192, 7168, 4096, 4),  # Config 1
        (8192, 2048, 7168, 4),  # Config 2
        (4096, 7168, 4096, 8),  # Config 3
        (4096, 2048, 7168, 8),  # Config 4
    ]

    results = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16

    for config_idx, (M, K, N, G) in enumerate(configs):
        logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
        logging.info(f"M={M}, K={K}, N={N}, G={G}")

        # Create group sizes for M dimension
        base_size = M // G
        remainder = M % G
        M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create tensors
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)

        # Benchmark PyTorch reference
        torch.cuda.synchronize()
        compute_reference_forward(x, w, m_sizes)  # Warmup
        torch.cuda.synchronize()

        logging.info("Benchmarking PyTorch reference...")
        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        for _ in range(10):
            compute_reference_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()
        pt_time = (end_time - start_time) / 10
        pt_memory = torch.cuda.max_memory_allocated() / (1024**2)  # MB

        # Benchmark optimized kernel
        torch.cuda.synchronize()
        grouped_gemm_forward(x, w, m_sizes)  # Warmup
        torch.cuda.synchronize()

        logging.info("Benchmarking optimized kernel...")
        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        for _ in range(10):
            grouped_gemm_forward(x, w, m_sizes)
            torch.cuda.synchronize()
        end_time = time.time()
        opt_time = (end_time - start_time) / 10
        opt_memory = torch.cuda.max_memory_allocated() / (1024**2)  # MB

        # Calculate FLOPs and speedup
        flops = 2 * M * N * K
        pt_tflops = flops / pt_time / 1e12
        opt_tflops = flops / opt_time / 1e12
        speedup = pt_time / opt_time

        # Store results
        results.append(
            {
                "config": f"Config {config_idx + 1}",
                "dimensions": f"M={M}, K={K}, N={N}, G={G}",
                "pt_time_ms": pt_time * 1000,
                "opt_time_ms": opt_time * 1000,
                "pt_tflops": pt_tflops,
                "opt_tflops": opt_tflops,
                "speedup": speedup,
                "pt_memory_mb": pt_memory,
                "opt_memory_mb": opt_memory,
                "memory_savings": (
                    (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
                ),
            }
        )

        logging.info(
            f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
        )
        logging.info(
            f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
        )
        logging.info(
            f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
        )

    # Print summary table
    logging.info("\n===== Benchmark Results Summary =====")
    logging.info(
        f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
    )
    logging.info(
        f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
        f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
    )
    logging.info("-" * 100)

    for result in results:
        logging.info(
            f"{result['config']:<10} | "
            f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
            f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
            f"{result['speedup']:<10.2f} | "
            f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
            f"{result['memory_savings']:<12.2f}%"
        )

    return results

def plot_benchmark_results(results):
    """
    Plot benchmark results as bar charts.
    """
    # Extract data
    configs = [r["config"] for r in results]
    pt_tflops = [r["pt_tflops"] for r in results]
    opt_tflops = [r["opt_tflops"] for r in results]
    speedups = [r["speedup"] for r in results]

    # Create figure with subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot TFLOPS comparison
    x = np.arange(len(configs))
    width = 0.35
    ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
    ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
    ax1.set_xlabel("Model Configuration")
    ax1.set_ylabel("TFLOPS")
    ax1.set_title("Performance Comparison (Higher is Better)")
    ax1.set_xticks(x)
    ax1.set_xticklabels(configs)
    ax1.legend()
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    # Plot speedup
    ax2.bar(x, speedups, width=0.6, color="green")
    ax2.set_xlabel("Model Configuration")
    ax2.set_ylabel("Speedup (x)")
    ax2.set_title("Speedup Factor (Higher is Better)")
    ax2.set_xticks(x)
    ax2.set_xticklabels(configs)
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    # Add speedup values on top of bars
    for i, v in enumerate(speedups):
        ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")

    plt.tight_layout()
    plt.savefig("mg_grouped_gemm_benchmark_results.png")
    logging.info(
        "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
    )

def compare_mg_implementations():
    """
    Combine the M*G and N*G benchmark results for comparison.
    """
    # Only run this if both NG and MG benchmarks have been run
    try:
        import pandas as pd

        # Try to load previous benchmark results
        mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
        ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")

        # Create comparison plot
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Plot speedup comparison
        configs = mg_results["config"].unique()
        mg_speedups = mg_results.groupby("config")["speedup"].mean()
        ng_speedups = ng_results.groupby("config")["speedup"].mean()

        x = np.arange(len(configs))
        width = 0.35

        axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
        axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
        axes[0].set_xlabel("Model Configuration")
        axes[0].set_ylabel("Speedup (x)")
        axes[0].set_title("Speedup Comparison: M*G vs N*G")
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(configs)
        axes[0].legend()
        axes[0].grid(axis="y", linestyle="--", alpha=0.7)

        # Plot TFLOPS comparison for optimized kernels
        mg_tflops = (
            mg_results[mg_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )
        ng_tflops = (
            ng_results[ng_results["implementation"] == "optimized"]
            .groupby("config")["tflops"]
            .mean()
        )

        axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
        axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
        axes[1].set_xlabel("Model Configuration")
        axes[1].set_ylabel("TFLOPS")
        axes[1].set_title("Performance Comparison: M*G vs N*G")
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(configs)
        axes[1].legend()
        axes[1].grid(axis="y", linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig("mg_vs_ng_comparison.png")
        logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")

    except Exception as e:
        logging.error(f"Could not create comparison plot: {e}")
        logging.info(
            "Run both M*G and N*G benchmarks first to generate comparison plots"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark M*G Grouped GEMM implementations"
    )
    parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
    parser.add_argument(
        "--triton-bench", action="store_true", help="Run Triton performance reports"
    )
    parser.add_argument(
        "--model-configs", action="store_true", help="Benchmark model configurations"
    )
    parser.add_argument(
        "--compare-mg-ng",
        action="store_true",
        help="Compare M*G and N*G implementations",
    )
    args = parser.parse_args()

    # Check if CUDA is available
    if not torch.cuda.is_available():
        logging.error(
            "CUDA is not available. This benchmark requires a CUDA-capable GPU."
        )
        exit(1)

    if args.run_all or args.model_configs:
        # Benchmark model configurations
        logging.info("Running benchmark for model configurations...")
        results = benchmark_model_configs()
        plot_benchmark_results(results)

    if args.run_all or args.triton_bench:
        # Run Triton performance reports
        logging.info("Running Triton performance reports...")
        benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
        benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
        logging.info(
            "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
        )

    if args.run_all or args.compare_mg_ng:
        # Compare M*G and N*G implementations
        logging.info("Comparing M*G and N*G implementations...")
        compare_mg_implementations()
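Editor's note: a minimal programmatic-use sketch for the benchmark above (an illustrative assumption, roughly equivalent to passing --model-configs and --triton-bench on the command line; it requires a CUDA GPU and the torchao_pr package on the path):

    if torch.cuda.is_available():
        results = benchmark_model_configs()
        plot_benchmark_results(results)
        # each perf_report-decorated function exposes .run() for the Triton sweeps
        benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
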
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py
ADDED
@@ -0,0 +1,1304 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
import functools
import logging

import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch

import triton
import triton.language as tl
from triton import Config as TConfig

from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from tma_autotuning import (
    ALIGN_SIZE_M,
    _NV_CONFIGS,
    CudaUtils,
    early_config_prune,
    TmaDescriptorHelper,
)


# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# ============== Start Triton Kernels ===============

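# Editor's note (illustrative, not in the original file): in the M*G convention used
# by the kernels below, the row dimension of `x` ([M_total, K]) is partitioned into
# G contiguous blocks whose lengths are given by `m_sizes`; e.g. m_sizes = [3, 0, 5]
# means rows 0:3 belong to group 0, group 1 is empty, and rows 3:8 belong to group 2.
# The kernels recover these boundaries by walking the running sum of m_sizes.
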
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0
    # Size of individual weight matrix
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        n_start = n_size * g

        if m_size > 0:
            # Process this group

            # Acquire hold on c_desc_ptr for TMA Store
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # columnwise
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles

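# Editor's note: a host-side sketch (an assumption, for illustration only) of the flat
# tile indexing the persistent kernel above walks through. Each program id owns the
# tiles {pid, pid + NUM_SMS, pid + 2*NUM_SMS, ...} counted across all non-empty
# groups, which is what the `while tbidx >= processed_tiles ...` loop reproduces.
def _flat_tiles_for_pid(pid, m_sizes, n_size, block_m, block_n, num_sms):
    import math

    total = 0
    for m in m_sizes:
        if m > 0:
            total += math.ceil(m / block_m) * math.ceil(n_size / block_n)
    return list(range(pid, total, num_sms))
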
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    a_scale_ptr,
    b_scale_ptr,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_FP8: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # TMA Store prep
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles

@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,
    b_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty
    c_desc_ptr = None

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking
                    a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                    b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)

                    # Main matmul
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                c = accumulator.to(c_dtype)

                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    # combine row and column bounds masks elementwise
                    mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles

"""
|
399 |
+
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
400 |
+
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
401 |
+
"""
|
402 |
+
|
403 |
+
|
404 |
+
# ---- dx flat linear indexed ----
|
405 |
+
@triton.autotune(
|
406 |
+
configs=_NV_CONFIGS,
|
407 |
+
key=["G", "M_BUCKET", "N", "K"],
|
408 |
+
prune_configs_by={"early_config_prune": early_config_prune},
|
409 |
+
)
|
410 |
+
@triton.jit
|
411 |
+
def _kernel_mg_dx_tma(
|
412 |
+
grad_output_desc_ptr, # [MG, N]
|
413 |
+
w_desc_ptr, # [N, K]
|
414 |
+
grad_input_ptr, # output grad_x [MG, K]
|
415 |
+
workspace, # for TMA store
|
416 |
+
m_sizes, # group sizes [G]
|
417 |
+
# problem sizes
|
418 |
+
G: tl.constexpr,
|
419 |
+
M_BUCKET: tl.constexpr,
|
420 |
+
N: tl.constexpr,
|
421 |
+
K: tl.constexpr,
|
422 |
+
# config
|
423 |
+
NUM_SMS: tl.constexpr,
|
424 |
+
USE_TMA_LOAD: tl.constexpr,
|
425 |
+
USE_TMA_STORE: tl.constexpr,
|
426 |
+
TMA_SIZE: tl.constexpr,
|
427 |
+
# tiles
|
428 |
+
BLOCK_SIZE_M: tl.constexpr,
|
429 |
+
BLOCK_SIZE_N: tl.constexpr,
|
430 |
+
BLOCK_SIZE_K: tl.constexpr,
|
431 |
+
) -> None:
|
432 |
+
"""
|
433 |
+
TMA-optimized kernel for computing gradients with respect to input (dx).
|
434 |
+
For the forward pass Y = X @ W.T, the backward for input is:
|
435 |
+
grad_X = grad_Y @ W
|
436 |
+
|
437 |
+
This maps to [MG, N] @ [N, K] -> [MG, K]
|
438 |
+
|
439 |
+
Key differences from forward:
|
440 |
+
1. W is used directly and not transposed
|
441 |
+
2. The reduction dimension is now N (not K)
|
442 |
+
3. Output is [M, K] instead of [M, N]
|
443 |
+
"""
|
444 |
+
tbidx = tl.program_id(0) # thread block index
|
445 |
+
|
446 |
+
c_dtype = grad_input_ptr.dtype.element_ty
|
447 |
+
c_desc_ptr = workspace + (tbidx * TMA_SIZE)
|
448 |
+
|
449 |
+
M_end = 0
|
450 |
+
processed_tiles = 0
|
451 |
+
|
452 |
+
for g in range(G):
|
453 |
+
# Move down along groups - same as forward
|
454 |
+
M_start = M_end
|
455 |
+
m_size = tl.load(m_sizes + g)
|
456 |
+
M_end = M_start + m_size
|
457 |
+
|
458 |
+
if m_size > 0:
|
459 |
+
# Process this group
|
460 |
+
# tiles for this group - now producing [M, K] output
|
461 |
+
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
462 |
+
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
463 |
+
group_num_tiles = num_m_tiles * num_k_tiles
|
464 |
+
|
465 |
+
# TMA Store prep for [M, K] output
|
466 |
+
tl.extra.cuda.experimental_device_tensormap_create2d(
|
467 |
+
desc_ptr=c_desc_ptr,
|
468 |
+
global_address=grad_input_ptr + M_start * K,
|
469 |
+
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
470 |
+
global_size=[m_size, K],
|
471 |
+
element_ty=c_dtype,
|
472 |
+
)
|
473 |
+
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
|
474 |
+
|
475 |
+
while tbidx >= processed_tiles and tbidx < (
|
476 |
+
processed_tiles + group_num_tiles
|
477 |
+
):
|
478 |
+
group_index = tbidx - processed_tiles
|
479 |
+
|
480 |
+
# Different tiling scheme for [M, K] output
|
481 |
+
tile_m_index = group_index % num_m_tiles
|
482 |
+
tile_k_index = group_index // num_m_tiles
|
483 |
+
|
484 |
+
# for grad_input block [M, K]
|
485 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
486 |
+
|
487 |
+
# Position in full matrix
|
488 |
+
m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
|
489 |
+
k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
490 |
+
|
491 |
+
# reduce along N dimension (instead of K in forward)
|
492 |
+
for n_offset in range(0, N, BLOCK_SIZE_N):
|
493 |
+
# grad_output block [M, N]
|
494 |
+
grad_output = tl._experimental_descriptor_load(
|
495 |
+
grad_output_desc_ptr,
|
496 |
+
[m_offset, n_offset],
|
497 |
+
[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
498 |
+
c_dtype,
|
499 |
+
)
|
500 |
+
|
501 |
+
# weight block [N, K] - no transpose needed
|
502 |
+
w = tl._experimental_descriptor_load(
|
503 |
+
w_desc_ptr,
|
504 |
+
[n_offset, k_offset],
|
505 |
+
[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
506 |
+
c_dtype,
|
507 |
+
)
|
508 |
+
|
509 |
+
# grad_x = grad_output @ w
|
510 |
+
# reducing along N dimension
|
511 |
+
accumulator += tl.dot(grad_output, w)
|
512 |
+
|
513 |
+
# Store using TMA
|
514 |
+
m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
515 |
+
# k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
|
516 |
+
|
517 |
+
tl._experimental_descriptor_store(
|
518 |
+
c_desc_ptr,
|
519 |
+
accumulator.to(c_dtype),
|
520 |
+
[m_offset, k_offset],
|
521 |
+
)
|
522 |
+
|
523 |
+
# Move to the next tile
|
524 |
+
tbidx += NUM_SMS
|
525 |
+
|
526 |
+
# Update the total tiles count for the next group
|
527 |
+
processed_tiles += group_num_tiles
|
528 |
+
|
529 |
+
|
530 |
+
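# Editor's note (illustrative sketch only, not part of the original file): the dx
# kernel above implements grad_X = grad_Y @ W, i.e. a plain [M_total, N] x [N, K]
# matmul executed per group row-block; a small torch.autograd check on
# (x @ w.T).sum().backward() reproduces the same gradient for x, which is one way
# to validate it numerically on small shapes.
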
# ---- dw flat linear indexed ----


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K]
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N]
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # workspace for TMA store
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for reduction dimension
) -> None:
    """
    Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses flat index structure similar to forward.

    For the forward pass Y = X @ W.T,
    the backward for weights is:
    grad_W = grad_Y.T @ X

    Where:
    - grad_Y is [MG, N]
    - X is [MG, K]
    - grad_W is [N, K]
    - we return [N,K]
    """
    # Get thread block index
    tbidx = tl.program_id(0)

    # Get output data type
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in strided manner across SMs
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Calculate tile indices
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global offsets
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Process each group
        M_end = 0
        for g in range(G):
            # Get group boundaries
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Calculate actual block size (handling boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # Apply masks for valid regions
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out invalid elements
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual load with bounds checking
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Compute partial contribution: grad_W += grad_Y.T @ X
                        # transpose grad_output for the matmul
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate
                        accumulator += contribution

        # Store the result
        if USE_TMA_STORE:
            # Store using TMA
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )

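# Editor's note (sketch of the math only, not in the original file): the dw kernel
# accumulates grad_W = grad_Y.T @ X across every group's row block, so each
# [BLOCK_SIZE_N, BLOCK_SIZE_K] output tile sums contributions from the whole M_total
# dimension; unlike dx, the output shape [N, K] does not depend on m_sizes at all.
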
# ======== End Triton kernels ========

# ======== Triton wrapper functions ========

# ----- main forward pass wrapper -----


def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.

    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is now [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is now the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # Create output tensor with shape [M_total, N // G] (each group uses its own N // G slice of w)
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y

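# Editor's note: a minimal call sketch for the wrapper above (an illustrative
# assumption; it requires a TMA-capable GPU since verify_tma() must pass, and uses
# the [M_total, K] input / [N, K] weight layout the function asserts).
def _forward_usage_sketch(device="cuda"):
    G, M_total, K, N = 4, 1024, 512, 1024
    m_sizes = torch.full((G,), M_total // G, device=device, dtype=torch.int32)
    x = torch.randn(M_total, K, device=device, dtype=torch.bfloat16)
    w = torch.randn(N, K, device=device, dtype=torch.bfloat16)
    # per the code above, the result has shape [M_total, N // G]
    return grouped_gemm_forward(x, w, m_sizes)
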
# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)
    """
    logging.info("Starting unified grouped_gemm_backward")

    # do this once, seems expensive
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using flat linear implementation
    try:
        logging.info(f"Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    # Compute grad_w using flat linear style implementation
    try:
        logging.info(f"Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w


# ----- dx backward pass wrapper -----


def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing gradient with respect to input (dx)
    using TMA patterns similar to the forward pass.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        tma_size: Size of TMA descriptor
        # using_fp8: Whether to use FP8 quantization
        # grad_output_scale: Scale for grad_output in FP8 mode
        # w_scale: Scale for w in FP8 mode

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]
    """
    """
    Optimized backward pass for computing gradient with respect to input (dx)
    using TMA patterns similar to the forward pass.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        tma_size: Size of TMA descriptor
        using_fp8: Whether to use FP8 quantization
        # grad_output_scale: Scale for grad_output in FP8 mode
        # w_scale: Scale for w in FP8 mode

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for TMA store
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors with appropriate dimensions
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
desc_helper.fill_2d_tma_descriptor(
|
1023 |
+
"w",
|
1024 |
+
w.data_ptr(),
|
1025 |
+
N_w,
|
1026 |
+
K,
|
1027 |
+
META["BLOCK_SIZE_N"],
|
1028 |
+
META["BLOCK_SIZE_K"],
|
1029 |
+
w.element_size(),
|
1030 |
+
)
|
1031 |
+
return (NUM_SMS,)
|
1032 |
+
|
1033 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
1034 |
+
|
1035 |
+
# Launch the flat linear kernel for computing grad_x
|
1036 |
+
_kernel_mg_dx_tma[grid](
|
1037 |
+
desc_grad_output,
|
1038 |
+
desc_w,
|
1039 |
+
grad_x,
|
1040 |
+
workspace,
|
1041 |
+
m_sizes,
|
1042 |
+
G,
|
1043 |
+
M_BUCKET,
|
1044 |
+
N_grad, # N dimension is now the reduction dimension
|
1045 |
+
K,
|
1046 |
+
NUM_SMS,
|
1047 |
+
USE_TMA_LOAD,
|
1048 |
+
USE_TMA_STORE,
|
1049 |
+
TMA_SIZE=tma_size,
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
return grad_x
|
1053 |
+
|
1054 |
+
|
1055 |
+
# ======== dw wrapper function ==========
|
1056 |
+
|
1057 |
+
|
1058 |
+
def grouped_gemm_dw_tma(
|
1059 |
+
x: torch.Tensor,
|
1060 |
+
grad_output: torch.Tensor,
|
1061 |
+
m_sizes: torch.Tensor,
|
1062 |
+
num_sms: int = 132,
|
1063 |
+
tma_size: int = 128,
|
1064 |
+
) -> torch.Tensor:
|
1065 |
+
"""
|
1066 |
+
Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
|
1067 |
+
For the forward pass Y = X @ W.T, the backward for weights is:
|
1068 |
+
grad_W = grad_Y.T @ X
|
1069 |
+
|
1070 |
+
Args:
|
1071 |
+
x: Input tensor, shape [M_total, K]
|
1072 |
+
grad_output: Gradient of output, shape [M_total, N]
|
1073 |
+
m_sizes: Group sizes tensor, shape [G]
|
1074 |
+
tma_size: Size of TMA descriptor in bytes
|
1075 |
+
|
1076 |
+
|
1077 |
+
Returns:
|
1078 |
+
grad_w: Gradient with respect to weights, shape [N, K]
|
1079 |
+
"""
|
1080 |
+
# Check TMA support
|
1081 |
+
has_tma_support = CudaUtils.verify_tma()
|
1082 |
+
|
1083 |
+
# Get group count
|
1084 |
+
G = m_sizes.shape[0]
|
1085 |
+
|
1086 |
+
# Ensure contiguous tensors
|
1087 |
+
x = x.contiguous()
|
1088 |
+
grad_output = grad_output.contiguous()
|
1089 |
+
m_sizes = m_sizes.contiguous()
|
1090 |
+
|
1091 |
+
# Get dimensions
|
1092 |
+
M_total, K_x = x.shape
|
1093 |
+
M_grad, N = grad_output.shape
|
1094 |
+
|
1095 |
+
# Check dimensions
|
1096 |
+
assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"
|
1097 |
+
|
1098 |
+
# Verify that the sum of m_sizes matches M_total
|
1099 |
+
sum_m_sizes = m_sizes.sum().item()
|
1100 |
+
assert (
|
1101 |
+
sum_m_sizes == M_total
|
1102 |
+
), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
1103 |
+
|
1104 |
+
# Create output tensor (grad_w) with shape [N, K]
|
1105 |
+
grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)
|
1106 |
+
|
1107 |
+
NUM_SMS = num_sms
|
1108 |
+
|
1109 |
+
# TODO - hardcoded for now...but should set TMA flags based on hardware support
|
1110 |
+
USE_TMA_LOAD = True # has_tma_support
|
1111 |
+
USE_TMA_STORE = True # has_tma_support
|
1112 |
+
|
1113 |
+
# Set up TMA descriptors or direct pointers
|
1114 |
+
if USE_TMA_LOAD or USE_TMA_STORE:
|
1115 |
+
desc_helper = TmaDescriptorHelper(tma_size=tma_size)
|
1116 |
+
|
1117 |
+
if USE_TMA_LOAD:
|
1118 |
+
desc_helper.init_tma_descriptor("x")
|
1119 |
+
desc_helper.init_tma_descriptor("grad_output")
|
1120 |
+
x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
|
1121 |
+
grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
|
1122 |
+
"grad_output"
|
1123 |
+
)
|
1124 |
+
else:
|
1125 |
+
x_desc = x
|
1126 |
+
grad_output_desc = grad_output
|
1127 |
+
|
1128 |
+
if USE_TMA_STORE:
|
1129 |
+
desc_helper.init_tma_descriptor("grad_w")
|
1130 |
+
workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
|
1131 |
+
else:
|
1132 |
+
workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
|
1133 |
+
else:
|
1134 |
+
# If not using TMA, just use the tensors directly
|
1135 |
+
x_desc = x
|
1136 |
+
grad_output_desc = grad_output
|
1137 |
+
workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
|
1138 |
+
|
1139 |
+
# M_BUCKET for grid size
|
1140 |
+
M_BUCKET = triton.next_power_of_2(M_total)
|
1141 |
+
|
1142 |
+
# Define grid for kernel launch
|
1143 |
+
def grid(META):
|
1144 |
+
if USE_TMA_LOAD or USE_TMA_STORE:
|
1145 |
+
|
1146 |
+
if USE_TMA_LOAD:
|
1147 |
+
desc_helper.fill_2d_tma_descriptor(
|
1148 |
+
"x",
|
1149 |
+
x.data_ptr(),
|
1150 |
+
M_total,
|
1151 |
+
K_x,
|
1152 |
+
META["BLOCK_SIZE_M"],
|
1153 |
+
META["BLOCK_SIZE_K"],
|
1154 |
+
x.element_size(),
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
desc_helper.fill_2d_tma_descriptor(
|
1158 |
+
"grad_output",
|
1159 |
+
grad_output.data_ptr(),
|
1160 |
+
M_total,
|
1161 |
+
N,
|
1162 |
+
META["BLOCK_SIZE_M"],
|
1163 |
+
META["BLOCK_SIZE_N"],
|
1164 |
+
grad_output.element_size(),
|
1165 |
+
)
|
1166 |
+
|
1167 |
+
if USE_TMA_STORE:
|
1168 |
+
desc_helper.fill_2d_tma_descriptor(
|
1169 |
+
"grad_w",
|
1170 |
+
grad_w.data_ptr(),
|
1171 |
+
N,
|
1172 |
+
K_x,
|
1173 |
+
META["BLOCK_SIZE_N"],
|
1174 |
+
META["BLOCK_SIZE_K"],
|
1175 |
+
grad_w.element_size(),
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
# Return grid size - one block per SM for balanced work distribution
|
1179 |
+
return (NUM_SMS,)
|
1180 |
+
|
1181 |
+
# Launch the optimized kernel
|
1182 |
+
_kernel_mg_dw_tma[grid](
|
1183 |
+
x_desc,
|
1184 |
+
grad_output_desc,
|
1185 |
+
grad_w,
|
1186 |
+
workspace,
|
1187 |
+
m_sizes,
|
1188 |
+
G,
|
1189 |
+
M_BUCKET,
|
1190 |
+
N,
|
1191 |
+
K_x,
|
1192 |
+
NUM_SMS,
|
1193 |
+
USE_TMA_LOAD,
|
1194 |
+
USE_TMA_STORE,
|
1195 |
+
TMA_SIZE=tma_size,
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
return grad_w
|
1199 |
+
|
1200 |
+
|
1201 |
+
# ======== End Backwards Wrapper Functions =============
|
1202 |
+
|
1203 |
+
# ======== PyTorch wrapper functions ========
|
1204 |
+
|
1205 |
+
|
1206 |
+
class GroupedGEMM_mg(torch.autograd.Function):
|
1207 |
+
"""
|
1208 |
+
Autograd function for GroupedGEMM with M*G grouping.
|
1209 |
+
Supports both standard and FP8 quantized operations.
|
1210 |
+
"""
|
1211 |
+
|
1212 |
+
@staticmethod
|
1213 |
+
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
|
1214 |
+
"""
|
1215 |
+
Forward pass of GroupedGEMM.
|
1216 |
+
|
1217 |
+
Args:
|
1218 |
+
x: Input tensor, shape [M_total, K]
|
1219 |
+
w: Weight tensor, shape [N, K]
|
1220 |
+
m_sizes: Tensor of shape [G] containing the size of each group
|
1221 |
+
use_tma: Whether to try using TMA acceleration (if available)
|
1222 |
+
tma_size: Size of TMA descriptor in bytes
|
1223 |
+
using_fp8: Whether to use FP8 quantization
|
1224 |
+
|
1225 |
+
Returns:
|
1226 |
+
Output tensor, shape [M_total, N]
|
1227 |
+
"""
|
1228 |
+
|
1229 |
+
# Use regular forward without quantization
|
1230 |
+
output = grouped_gemm_forward(
|
1231 |
+
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
# Save inputs and parameters for backward pass
|
1235 |
+
ctx.save_for_backward(x, w, m_sizes)
|
1236 |
+
ctx.use_tma = use_tma
|
1237 |
+
ctx.tma_size = tma_size
|
1238 |
+
|
1239 |
+
ctx.save_for_backward(x, w, m_sizes)
|
1240 |
+
|
1241 |
+
return output
|
1242 |
+
|
1243 |
+
@staticmethod
|
1244 |
+
def backward(ctx, grad_output):
|
1245 |
+
"""
|
1246 |
+
Backward pass of M*G GroupedGEMM.
|
1247 |
+
|
1248 |
+
Args:
|
1249 |
+
grad_output: Gradient of output, shape [M_total, N]
|
1250 |
+
|
1251 |
+
Returns:
|
1252 |
+
Tuple of gradients:
|
1253 |
+
- grad_x: Gradient with respect to x, shape [M_total, K]
|
1254 |
+
- grad_w: Gradient with respect to w, shape [N, K]
|
1255 |
+
- None: Gradient with respect to m_sizes (not differentiable)
|
1256 |
+
- None: Gradient with respect to use_tma (not differentiable)
|
1257 |
+
- None: Gradient with respect to tma_size (not differentiable)
|
1258 |
+
|
1259 |
+
"""
|
1260 |
+
# Retrieve saved tensors and parameters
|
1261 |
+
|
1262 |
+
x, w, m_sizes = ctx.saved_tensors
|
1263 |
+
|
1264 |
+
use_tma = ctx.use_tma
|
1265 |
+
tma_size = ctx.tma_size
|
1266 |
+
|
1267 |
+
# Compute gradients using the unified implementation
|
1268 |
+
grad_x, grad_w = grouped_gemm_backward(
|
1269 |
+
grad_output=grad_output,
|
1270 |
+
x=x,
|
1271 |
+
w=w,
|
1272 |
+
m_sizes=m_sizes,
|
1273 |
+
use_tma=use_tma,
|
1274 |
+
tma_size=tma_size,
|
1275 |
+
)
|
1276 |
+
|
1277 |
+
# Return gradients for all inputs (None for non-differentiable parameters)
|
1278 |
+
return grad_x, grad_w, None, None
|
1279 |
+
|
1280 |
+
|
1281 |
+
def mg_grouped_gemm(
|
1282 |
+
x: torch.Tensor,
|
1283 |
+
w: torch.Tensor,
|
1284 |
+
m_sizes: torch.Tensor,
|
1285 |
+
use_tma: bool = True,
|
1286 |
+
tma_size: int = 128,
|
1287 |
+
using_fp8: bool = False,
|
1288 |
+
) -> torch.Tensor:
|
1289 |
+
"""
|
1290 |
+
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
|
1291 |
+
Supports both standard precision and FP8 quantized operations.
|
1292 |
+
|
1293 |
+
Args:
|
1294 |
+
x: Input tensor, shape [M_total, K]
|
1295 |
+
w: Weight tensor, shape [N, K]
|
1296 |
+
m_sizes: Tensor of shape [G] containing the size of each group
|
1297 |
+
use_tma: Whether to try using TMA acceleration (if available)
|
1298 |
+
tma_size: Size of TMA descriptor in bytes
|
1299 |
+
using_fp8: Whether to use FP8 quantization
|
1300 |
+
|
1301 |
+
Returns:
|
1302 |
+
Output tensor, shape [M_total, N]
|
1303 |
+
"""
|
1304 |
+
return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
|
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# pyre-unsafe
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
# Configure logging
|
14 |
+
logging.basicConfig(
|
15 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def compute_reference_forward(x, w, m_sizes):
|
20 |
+
"""
|
21 |
+
Compute reference forward pass using PyTorch operations.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
x (torch.Tensor): Input tensor of shape (M, K)
|
25 |
+
w (torch.Tensor): Weight tensor of shape (N, K)
|
26 |
+
m_sizes (torch.Tensor): Group sizes tensor of shape (G)
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
torch.Tensor: Reference output tensor of shape (M, N)
|
30 |
+
"""
|
31 |
+
result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
|
32 |
+
|
33 |
+
m_start = 0
|
34 |
+
for g in range(len(m_sizes)):
|
35 |
+
m_size = m_sizes[g].item()
|
36 |
+
if m_size > 0:
|
37 |
+
m_end = m_start + m_size
|
38 |
+
|
39 |
+
# Extract group input
|
40 |
+
x_g = x[m_start:m_end]
|
41 |
+
|
42 |
+
# Compute group output: y_g = x_g @ w.T
|
43 |
+
y_g = torch.matmul(x_g, w.T)
|
44 |
+
|
45 |
+
# Store result
|
46 |
+
result[m_start:m_end] = y_g
|
47 |
+
|
48 |
+
# Update start index
|
49 |
+
m_start = m_end
|
50 |
+
|
51 |
+
return result
|
52 |
+
|
53 |
+
|
54 |
+
def compute_reference_backward(x, w, m_sizes, grad_output):
|
55 |
+
"""
|
56 |
+
Compute reference backward pass using PyTorch autograd.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
x (torch.Tensor): Input tensor of shape (M, K)
|
60 |
+
w (torch.Tensor): Weight tensor of shape (N, K)
|
61 |
+
m_sizes (torch.Tensor): Group sizes tensor of shape (G)
|
62 |
+
grad_output (torch.Tensor): Gradient tensor of shape (M, N)
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
tuple: (grad_x, grad_w) gradient tensors
|
66 |
+
"""
|
67 |
+
# Create autograd-enabled copies
|
68 |
+
x_autograd = x.detach().clone().requires_grad_(True)
|
69 |
+
w_autograd = w.detach().clone().requires_grad_(True)
|
70 |
+
|
71 |
+
# Compute forward pass
|
72 |
+
output = compute_reference_forward(x_autograd, w_autograd, m_sizes)
|
73 |
+
|
74 |
+
# Backpropagate
|
75 |
+
output.backward(grad_output)
|
76 |
+
|
77 |
+
return x_autograd.grad, w_autograd.grad
|
78 |
+
|
79 |
+
|
80 |
+
def analyze_tensor_differences(actual, expected, name):
|
81 |
+
"""
|
82 |
+
Analyze differences between actual and expected tensors.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
actual (torch.Tensor): Actual tensor
|
86 |
+
expected (torch.Tensor): Expected tensor
|
87 |
+
name (str): Name of the tensor for logging
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
bool: True if tensors are close enough
|
91 |
+
"""
|
92 |
+
rtol = 0.5 # Relative tolerance for float16
|
93 |
+
atol = 0.5 # Absolute tolerance for float16
|
94 |
+
|
95 |
+
# Analyze differences
|
96 |
+
diff = (actual - expected).abs()
|
97 |
+
max_idx = diff.argmax().item()
|
98 |
+
idx = np.unravel_index(max_idx, actual.shape)
|
99 |
+
max_diff = diff.max().item()
|
100 |
+
|
101 |
+
logging.info(f"Largest {name} difference: {max_diff} at {idx}")
|
102 |
+
logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")
|
103 |
+
|
104 |
+
is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)
|
105 |
+
|
106 |
+
if is_close:
|
107 |
+
logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
|
108 |
+
else:
|
109 |
+
logging.error(f"✗ FAILURE: {name} mismatch detected")
|
110 |
+
|
111 |
+
# Count zeros
|
112 |
+
zeros_actual = (actual == 0).sum().item()
|
113 |
+
zeros_expected = (expected == 0).sum().item()
|
114 |
+
logging.info(
|
115 |
+
f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
|
116 |
+
)
|
117 |
+
logging.info(
|
118 |
+
f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Check for NaNs
|
122 |
+
nan_actual = torch.isnan(actual).sum().item()
|
123 |
+
if nan_actual > 0:
|
124 |
+
logging.error(f"NaN values detected in {name}: {nan_actual}")
|
125 |
+
|
126 |
+
return is_close
|
torchtitan/experiments/llama4/infra/parallelize_llama.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.distributed.device_mesh import DeviceMesh
|
11 |
+
|
12 |
+
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
|
13 |
+
from torchtitan.distributed import ParallelDims
|
14 |
+
|
15 |
+
from torchtitan.models.llama3.parallelize_llama import (
|
16 |
+
apply_ac,
|
17 |
+
apply_compile,
|
18 |
+
apply_ddp,
|
19 |
+
apply_fsdp,
|
20 |
+
apply_tp,
|
21 |
+
)
|
22 |
+
from torchtitan.tools.logging import logger
|
23 |
+
|
24 |
+
|
25 |
+
def parallelize_llama(
|
26 |
+
model: nn.Module,
|
27 |
+
world_mesh: DeviceMesh,
|
28 |
+
parallel_dims: ParallelDims,
|
29 |
+
job_config: JobConfig,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Apply tensor parallelism, activation checkpointing, torch.compile, and data
|
33 |
+
parallelism to the model.
|
34 |
+
|
35 |
+
NOTE: The passed-in model preferably should be on meta device. Otherwise,
|
36 |
+
the model must fit on GPU or CPU memory.
|
37 |
+
"""
|
38 |
+
|
39 |
+
if parallel_dims.tp_enabled:
|
40 |
+
if (
|
41 |
+
job_config.parallelism.enable_async_tensor_parallel
|
42 |
+
and not job_config.training.compile
|
43 |
+
):
|
44 |
+
raise RuntimeError("Async TP requires --training.compile")
|
45 |
+
|
46 |
+
enable_float8_linear = "float8" in job_config.model.converters
|
47 |
+
float8_is_rowwise = job_config.float8.recipe_name in (
|
48 |
+
"rowwise",
|
49 |
+
"rowwise_with_gw_hp",
|
50 |
+
)
|
51 |
+
|
52 |
+
# For now, float8 all-gather with TP is only supported for tensorwise
|
53 |
+
# float8 scaling recipes. For rowwise recipes, we use regular TP and
|
54 |
+
# all-gather happens in high precision.
|
55 |
+
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
|
56 |
+
|
57 |
+
apply_tp(
|
58 |
+
model,
|
59 |
+
world_mesh["tp"],
|
60 |
+
loss_parallel=parallel_dims.loss_parallel_enabled,
|
61 |
+
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
|
62 |
+
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
|
63 |
+
)
|
64 |
+
|
65 |
+
apply_moe_tp(model, world_mesh["tp"])
|
66 |
+
|
67 |
+
if job_config.activation_checkpoint.mode != "none":
|
68 |
+
if (
|
69 |
+
job_config.activation_checkpoint.mode == "selective"
|
70 |
+
and job_config.model.use_flex_attn
|
71 |
+
):
|
72 |
+
raise ValueError(
|
73 |
+
"FlexAttention is not compatible with selective AC yet. "
|
74 |
+
"See https://github.com/pytorch/pytorch/issues/147879"
|
75 |
+
)
|
76 |
+
apply_ac(model, job_config.activation_checkpoint)
|
77 |
+
|
78 |
+
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
|
79 |
+
if job_config.training.compile:
|
80 |
+
apply_compile(model)
|
81 |
+
|
82 |
+
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
|
83 |
+
torch._dynamo.config.capture_scalar_outputs = True
|
84 |
+
|
85 |
+
if (
|
86 |
+
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
|
87 |
+
): # apply FSDP or HSDP, potentially with Context Parallel
|
88 |
+
if parallel_dims.dp_replicate_enabled:
|
89 |
+
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
|
90 |
+
else:
|
91 |
+
dp_mesh_dim_names = ("dp_shard_cp",)
|
92 |
+
|
93 |
+
apply_fsdp(
|
94 |
+
model,
|
95 |
+
world_mesh[tuple(dp_mesh_dim_names)],
|
96 |
+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
|
97 |
+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
|
98 |
+
pp_enabled=parallel_dims.pp_enabled,
|
99 |
+
cpu_offload=job_config.training.enable_cpu_offload,
|
100 |
+
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
|
101 |
+
)
|
102 |
+
|
103 |
+
if parallel_dims.dp_replicate_enabled:
|
104 |
+
logger.info("Applied HSDP to the model")
|
105 |
+
else:
|
106 |
+
logger.info("Applied FSDP to the model")
|
107 |
+
|
108 |
+
if parallel_dims.cp_enabled:
|
109 |
+
logger.info("Applied Context Parallel to the model")
|
110 |
+
|
111 |
+
if job_config.training.enable_cpu_offload:
|
112 |
+
logger.info("Applied CPU Offloading to the model")
|
113 |
+
elif parallel_dims.dp_replicate_enabled:
|
114 |
+
if world_mesh.ndim > 1:
|
115 |
+
raise RuntimeError("DDP has not supported > 1D parallelism")
|
116 |
+
apply_ddp(
|
117 |
+
model,
|
118 |
+
world_mesh,
|
119 |
+
enable_compile=job_config.training.compile,
|
120 |
+
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
|
121 |
+
)
|
122 |
+
|
123 |
+
return model
|
124 |
+
|
125 |
+
|
126 |
+
def apply_moe_tp(
|
127 |
+
model: nn.Module,
|
128 |
+
tp_mesh: DeviceMesh,
|
129 |
+
):
|
130 |
+
from torch.distributed.tensor import Partial, Replicate, Shard
|
131 |
+
from torch.distributed.tensor.parallel import (
|
132 |
+
parallelize_module,
|
133 |
+
PrepareModuleInputOutput,
|
134 |
+
)
|
135 |
+
|
136 |
+
from .expert_parallel import NoParallel, TensorParallel
|
137 |
+
|
138 |
+
for _, transformer_block in model.layers.items():
|
139 |
+
moe_layer_plan = {
|
140 |
+
# input / output sharding on the seqlen dim
|
141 |
+
# all-gather for input, reduce-scatter for output
|
142 |
+
"moe": PrepareModuleInputOutput(
|
143 |
+
input_layouts=(Shard(1),),
|
144 |
+
desired_input_layouts=(Replicate(),),
|
145 |
+
use_local_input=True,
|
146 |
+
output_layouts=(Partial(),),
|
147 |
+
desired_output_layouts=(Shard(1),),
|
148 |
+
),
|
149 |
+
# replicate computation for the router
|
150 |
+
"moe.router.gate": NoParallel(),
|
151 |
+
# input Replicate, output Partial
|
152 |
+
"moe.experts": TensorParallel(),
|
153 |
+
"moe.shared_expert": TensorParallel(),
|
154 |
+
}
|
155 |
+
parallelize_module(
|
156 |
+
module=transformer_block,
|
157 |
+
device_mesh=tp_mesh,
|
158 |
+
parallelize_plan=moe_layer_plan,
|
159 |
+
)
|
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
|
|
torchtitan/experiments/llama4/model/model.py
ADDED
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
from torchtitan.models.attention import build_attention, init_attention_mask
|
13 |
+
from torchtitan.models.norms import build_norm
|
14 |
+
from torchtitan.protocols.train_spec import ModelProtocol
|
15 |
+
|
16 |
+
from .args import TransformerModelArgs
|
17 |
+
from .moe import MoE
|
18 |
+
|
19 |
+
|
20 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
23 |
+
|
24 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
25 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
26 |
+
The returned tensor contains complex values in complex64 data type.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
dim (int): Dimension of the frequency tensor.
|
30 |
+
end (int): End index for precomputing frequencies.
|
31 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
torch.Tensor: Precomputed frequency tensor with complex exponentials.
|
35 |
+
"""
|
36 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
37 |
+
t = torch.arange(end, device=freqs.device)
|
38 |
+
freqs = torch.outer(t, freqs).float()
|
39 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
40 |
+
return freqs_cis
|
41 |
+
|
42 |
+
|
43 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
44 |
+
"""
|
45 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
46 |
+
|
47 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
48 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
49 |
+
|
50 |
+
The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
|
51 |
+
and the first seqlen elements will be sliced, but dim must match x.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
55 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
torch.Tensor: Reshaped frequency tensor.
|
59 |
+
"""
|
60 |
+
ndim = x.ndim
|
61 |
+
assert ndim > 1
|
62 |
+
seqlen = x.shape[1]
|
63 |
+
freqs_cis = freqs_cis[0:seqlen]
|
64 |
+
assert freqs_cis.shape == (seqlen, x.shape[-1])
|
65 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
66 |
+
return freqs_cis.view(*shape)
|
67 |
+
|
68 |
+
|
69 |
+
def apply_rotary_emb(
|
70 |
+
xq: torch.Tensor,
|
71 |
+
xk: torch.Tensor,
|
72 |
+
freqs_cis: torch.Tensor,
|
73 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
74 |
+
"""
|
75 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
76 |
+
|
77 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
78 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
79 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
80 |
+
returned as real tensors.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings.
|
84 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings.
|
85 |
+
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
89 |
+
"""
|
90 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
91 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
92 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
93 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
94 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
95 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
96 |
+
|
97 |
+
|
98 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
99 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
100 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
101 |
+
if n_rep == 1:
|
102 |
+
return x
|
103 |
+
return (
|
104 |
+
torch.unsqueeze(x, dim=3)
|
105 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
106 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
class Attention(nn.Module):
|
111 |
+
"""
|
112 |
+
Multi-head attention module.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
model_args (TransformerModelArgs): Model configuration arguments.
|
116 |
+
|
117 |
+
Attributes:
|
118 |
+
n_kv_heads (int): Number of key and value heads.
|
119 |
+
n_heads (int): Number of query heads.
|
120 |
+
n_rep (int): Number of repetitions for local heads.
|
121 |
+
head_dim (int): Dimension size of each attention head.
|
122 |
+
wq (Linear): Linear transformation for queries.
|
123 |
+
wk (Linear): Linear transformation for keys.
|
124 |
+
wv (Linear): Linear transformation for values.
|
125 |
+
wo (Linear): Linear transformation for output.
|
126 |
+
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, model_args: TransformerModelArgs):
|
130 |
+
super().__init__()
|
131 |
+
self.n_heads = model_args.n_heads
|
132 |
+
self.n_kv_heads = (
|
133 |
+
model_args.n_heads
|
134 |
+
if model_args.n_kv_heads is None
|
135 |
+
else model_args.n_kv_heads
|
136 |
+
)
|
137 |
+
self.n_rep = self.n_heads // self.n_kv_heads
|
138 |
+
self.head_dim = model_args.dim // model_args.n_heads
|
139 |
+
|
140 |
+
self.wq = nn.Linear(
|
141 |
+
model_args.dim, model_args.n_heads * self.head_dim, bias=False
|
142 |
+
)
|
143 |
+
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
144 |
+
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
145 |
+
self.wo = nn.Linear(
|
146 |
+
model_args.n_heads * self.head_dim, model_args.dim, bias=False
|
147 |
+
)
|
148 |
+
self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
|
149 |
+
|
150 |
+
def init_weights(self, init_std: float):
|
151 |
+
for linear in (self.wq, self.wk, self.wv):
|
152 |
+
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
|
153 |
+
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
|
154 |
+
|
155 |
+
def forward(
|
156 |
+
self,
|
157 |
+
x: torch.Tensor,
|
158 |
+
freqs_cis: torch.Tensor,
|
159 |
+
):
|
160 |
+
"""
|
161 |
+
Forward pass of the attention module.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
x (torch.Tensor): Input tensor.
|
165 |
+
freqs_cis (torch.Tensor): Precomputed frequency tensor.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
torch.Tensor: Output tensor after attention.
|
169 |
+
|
170 |
+
"""
|
171 |
+
|
172 |
+
bs, seqlen, _ = x.shape
|
173 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
174 |
+
|
175 |
+
# Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
|
176 |
+
# local heads from sizes of xq, xk, and xv as TP may have sharded them
|
177 |
+
# after the above linear ops.
|
178 |
+
xq = xq.view(bs, seqlen, -1, self.head_dim)
|
179 |
+
xk = xk.view(bs, seqlen, -1, self.head_dim)
|
180 |
+
xv = xv.view(bs, seqlen, -1, self.head_dim)
|
181 |
+
|
182 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
183 |
+
|
184 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
185 |
+
keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
186 |
+
values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
187 |
+
|
188 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
189 |
+
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
190 |
+
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
191 |
+
|
192 |
+
output = self.sdpa(xq, xk, xv)
|
193 |
+
|
194 |
+
output = output.transpose(
|
195 |
+
1, 2
|
196 |
+
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
|
197 |
+
output = output.view(bs, seqlen, -1)
|
198 |
+
return self.wo(output)
|
199 |
+
|
200 |
+
|
201 |
+
class FeedForward(nn.Module):
|
202 |
+
"""
|
203 |
+
FeedForward module
|
204 |
+
|
205 |
+
Args:
|
206 |
+
dim (int): Input dimension.
|
207 |
+
hidden_dim (int): Hidden dimension of the feedforward layer.
|
208 |
+
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
209 |
+
ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
|
210 |
+
|
211 |
+
Attributes:
|
212 |
+
w1 (Linear): Linear transformation for the first layer.
|
213 |
+
w2 (Linear): Linear transformation for the second layer.
|
214 |
+
w3 (Linear): Linear transformation for the third layer.
|
215 |
+
|
216 |
+
"""
|
217 |
+
|
218 |
+
def __init__(
|
219 |
+
self,
|
220 |
+
dim: int,
|
221 |
+
hidden_dim: int,
|
222 |
+
multiple_of: int,
|
223 |
+
ffn_dim_multiplier: float | None,
|
224 |
+
):
|
225 |
+
super().__init__()
|
226 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
227 |
+
# custom dim factor multiplier
|
228 |
+
if ffn_dim_multiplier is not None:
|
229 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
230 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
231 |
+
|
232 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
233 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
234 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
238 |
+
|
239 |
+
def init_weights(self, init_std: float):
|
240 |
+
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
|
241 |
+
for linear in (self.w2, self.w3):
|
242 |
+
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
|
243 |
+
|
244 |
+
|
245 |
+
class TransformerBlock(nn.Module):
|
246 |
+
"""
|
247 |
+
TransformerBlock Module
|
248 |
+
|
249 |
+
Args:
|
250 |
+
layer_id (int): Identifier for the layer.
|
251 |
+
model_args (TransformerModelArgs): Model configuration arguments.
|
252 |
+
|
253 |
+
Attributes:
|
254 |
+
n_heads (int): Number of attention heads.
|
255 |
+
dim (int): Dimension size of the model.
|
256 |
+
head_dim (int): Dimension size of each attention head.
|
257 |
+
attention (Attention): Attention module.
|
258 |
+
feed_forward (FeedForward): FeedForward module.
|
259 |
+
layer_id (int): Identifier for the layer.
|
260 |
+
attention_norm (RMSNorm): Layer normalization for attention output.
|
261 |
+
ffn_norm (RMSNorm): Layer normalization for feedforward output.
|
262 |
+
|
263 |
+
"""
|
264 |
+
|
265 |
+
def __init__(self, layer_id: int, model_args: TransformerModelArgs):
|
266 |
+
super().__init__()
|
267 |
+
self.n_heads = model_args.n_heads
|
268 |
+
self.dim = model_args.dim
|
269 |
+
self.attention = Attention(model_args)
|
270 |
+
|
271 |
+
# use MoE layer for every interleave_moe_layer_step FFN layers
|
272 |
+
self.moe_enabled = (
|
273 |
+
model_args.moe_enabled
|
274 |
+
and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
|
275 |
+
)
|
276 |
+
if self.moe_enabled:
|
277 |
+
self.moe = MoE(model_args)
|
278 |
+
else:
|
279 |
+
self.feed_forward = FeedForward(
|
280 |
+
dim=model_args.dim,
|
281 |
+
hidden_dim=4 * model_args.dim,
|
282 |
+
multiple_of=model_args.multiple_of,
|
283 |
+
ffn_dim_multiplier=model_args.ffn_dim_multiplier,
|
284 |
+
)
|
285 |
+
|
286 |
+
self.layer_id = layer_id
|
287 |
+
self.num_layers = model_args.n_layers
|
288 |
+
|
289 |
+
self.attention_norm = build_norm(
|
290 |
+
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
|
291 |
+
)
|
292 |
+
self.ffn_norm = build_norm(
|
293 |
+
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
|
294 |
+
)
|
295 |
+
|
296 |
+
if model_args.depth_init:
|
297 |
+
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
|
298 |
+
else:
|
299 |
+
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
|
300 |
+
|
301 |
+
def forward(
|
302 |
+
self,
|
303 |
+
x: torch.Tensor,
|
304 |
+
freqs_cis: torch.Tensor,
|
305 |
+
):
|
306 |
+
"""
|
307 |
+
Perform a forward pass through the TransformerBlock.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
x (torch.Tensor): Input tensor.
|
311 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
torch.Tensor: Output tensor after applying attention and feedforward layers.
|
315 |
+
|
316 |
+
"""
|
317 |
+
h = x + self.attention(self.attention_norm(x), freqs_cis)
|
318 |
+
if self.moe_enabled:
|
319 |
+
out = h + self.moe(self.ffn_norm(h))
|
320 |
+
else:
|
321 |
+
out = h + self.feed_forward(self.ffn_norm(h))
|
322 |
+
return out
|
323 |
+
|
324 |
+
def init_weights(self):
|
325 |
+
for norm in (self.attention_norm, self.ffn_norm):
|
326 |
+
norm.reset_parameters()
|
327 |
+
self.attention.init_weights(self.weight_init_std)
|
328 |
+
if self.moe_enabled:
|
329 |
+
self.moe.init_weights(self.weight_init_std)
|
330 |
+
else:
|
331 |
+
self.feed_forward.init_weights(self.weight_init_std)
|
332 |
+
|
333 |
+
|
334 |
+
class Transformer(nn.Module, ModelProtocol):
|
335 |
+
"""
|
336 |
+
Transformer Module
|
337 |
+
|
338 |
+
Args:
|
339 |
+
model_args (TransformerModelArgs): Model configuration arguments.
|
340 |
+
|
341 |
+
Attributes:
|
342 |
+
model_args (TransformerModelArgs): Model configuration arguments.
|
343 |
+
vocab_size (int): Vocabulary size.
|
344 |
+
n_layers (int): Number of layers in the model.
|
345 |
+
tok_embeddings (ParallelEmbedding): Token embeddings.
|
346 |
+
layers (torch.nn.ModuleList): List of Transformer blocks.
|
347 |
+
norm (RMSNorm): Layer normalization for the model output.
|
348 |
+
output (ColumnParallelLinear): Linear layer for final output.
|
349 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
350 |
+
|
351 |
+
"""
|
352 |
+
|
353 |
+
def __init__(self, model_args: TransformerModelArgs):
|
354 |
+
super().__init__()
|
355 |
+
self.model_args = model_args
|
356 |
+
self.vocab_size = model_args.vocab_size
|
357 |
+
self.n_layers = model_args.n_layers
|
358 |
+
self.eos_id = model_args.eos_id
|
359 |
+
|
360 |
+
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
|
361 |
+
|
362 |
+
# TODO persistent should be set to false, since this buffer can be recomputed.
|
363 |
+
# however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
|
364 |
+
# compile or pipeline-tracer will not correctly handle non-persistent buffers,
|
365 |
+
# so we need to fix that. (2) if we initialize pipeline-parallel models from
|
366 |
+
# a seed checkpoint rather than calling init_weights, we need freqs_cis to be
|
367 |
+
# initialized by the checkpoint, or we need to add a separate initializer for
|
368 |
+
# just the non-persistent buffers that is called after loading checkpoints.
|
369 |
+
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
|
370 |
+
|
371 |
+
self.layers = torch.nn.ModuleDict()
|
372 |
+
for layer_id in range(model_args.n_layers):
|
373 |
+
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
|
374 |
+
|
375 |
+
self.norm = build_norm(
|
376 |
+
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
|
377 |
+
)
|
378 |
+
|
379 |
+
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
|
380 |
+
self.init_weights()
|
381 |
+
|
382 |
+
def init_weights(
|
383 |
+
self,
|
384 |
+
buffer_device: torch.device | None = None,
|
385 |
+
):
|
386 |
+
"""
|
387 |
+
[Note: On ``init_weights`` vs. ``reset_parameters``]
|
388 |
+
Modules may define ``reset_parameters`` to initialize parameter values.
|
389 |
+
``reset_parameters`` is meant to only initialize directly owned
|
390 |
+
parameters/buffers, not those of their child modules, and it can be
|
391 |
+
used to give the initial values for these tensors.
|
392 |
+
Separately, users may want custom initialization for their modules,
|
393 |
+
different from that in ``reset_parameters``. For this, we define
|
394 |
+
``init_weights``. We only call it in the constructor of this
|
395 |
+
``Transformer`` root module to avoid reinitializing tensors.
|
396 |
+
"""
|
397 |
+
buffer_device = buffer_device or self.freqs_cis.device
|
398 |
+
with torch.device(buffer_device):
|
399 |
+
self.freqs_cis = self._precompute_freqs_cis()
|
400 |
+
if self.tok_embeddings is not None:
|
401 |
+
nn.init.normal_(self.tok_embeddings.weight)
|
402 |
+
for layer in self.layers.values():
|
403 |
+
if layer is not None:
|
404 |
+
layer.init_weights()
|
405 |
+
if self.norm is not None:
|
406 |
+
self.norm.reset_parameters()
|
407 |
+
final_out_std = self.model_args.dim**-0.5
|
408 |
+
cutoff_factor = 3
|
409 |
+
if self.output is not None:
|
410 |
+
nn.init.trunc_normal_(
|
411 |
+
self.output.weight,
|
412 |
+
mean=0.0,
|
413 |
+
std=final_out_std,
|
414 |
+
a=-cutoff_factor * final_out_std,
|
415 |
+
b=cutoff_factor * final_out_std,
|
416 |
+
)
|
417 |
+
|
418 |
+
def _precompute_freqs_cis(self) -> torch.Tensor:
|
419 |
+
return precompute_freqs_cis(
|
420 |
+
self.model_args.dim // self.model_args.n_heads,
|
421 |
+
# Need to compute until at least the max token limit for generation
|
422 |
+
# TODO: explain in docs/composability.md why we removed the 2x
|
423 |
+
# relaxing in our CP enablement PR
|
424 |
+
self.model_args.max_seq_len,
|
425 |
+
self.model_args.rope_theta,
|
426 |
+
)
|
427 |
+
|
428 |
+
def forward(self, tokens: torch.Tensor):
|
429 |
+
"""
|
430 |
+
Perform a forward pass through the Transformer model.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
tokens (torch.Tensor): Input token indices.
|
434 |
+
|
435 |
+
Returns:
|
436 |
+
torch.Tensor: Output logits after applying the Transformer model.
|
437 |
+
|
438 |
+
"""
|
439 |
+
# TODO: We will to change forward() signature to allow tokens to
|
440 |
+
# be always passed in.
|
441 |
+
if self.model_args.use_flex_attn:
|
442 |
+
init_attention_mask(tokens, eos_id=self.eos_id)
|
443 |
+
|
444 |
+
# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
|
445 |
+
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
|
446 |
+
|
447 |
+
for layer in self.layers.values():
|
448 |
+
h = layer(h, self.freqs_cis)
|
449 |
+
|
450 |
+
h = self.norm(h) if self.norm else h
|
451 |
+
output = self.output(h) if self.output else h
|
452 |
+
return output
|
453 |
+
|
454 |
+
@classmethod
|
455 |
+
def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
|
456 |
+
"""
|
457 |
+
Initialize a Transformer model from a TransformerModelArgs object.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
model_args (TransformerModelArgs): Model configuration arguments.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
Transformer: Transformer model.
|
464 |
+
|
465 |
+
"""
|
466 |
+
return cls(model_args)
|
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py
ADDED
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from typing import Any, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
|
16 |
+
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
|
17 |
+
from torchtitan.components.checkpoint import MODEL
|
18 |
+
from torchtitan.config_manager import JobConfig
|
19 |
+
from torchtitan.tools.logging import init_logger, logger
|
20 |
+
from torchtitan.train import Trainer
|
21 |
+
|
22 |
+
# Sharding dims for MP checkpoints
|
23 |
+
|
24 |
+
column_parallel = [
|
25 |
+
"tok_embeddings",
|
26 |
+
"wq",
|
27 |
+
"wk",
|
28 |
+
"wv",
|
29 |
+
"wqkv",
|
30 |
+
"w_in_shared_FD",
|
31 |
+
"w_out_eF_D",
|
32 |
+
"w_swiglu_FD",
|
33 |
+
"output",
|
34 |
+
"_linear",
|
35 |
+
"c_fc",
|
36 |
+
"vision_projection",
|
37 |
+
]
|
38 |
+
|
39 |
+
row_parallel = [
|
40 |
+
"wo",
|
41 |
+
"w_out_shared_DF",
|
42 |
+
"w_in_eD_F",
|
43 |
+
"moe_w_swiglu_eD_F",
|
44 |
+
"c_proj",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
def convert_to_titan_fqns(fqn: str) -> list[str]:
|
49 |
+
# From the stored checkpoint keys to TorchTitan keys.
|
50 |
+
if "wqkv" in fqn and "layer_norm_weight" not in fqn:
|
51 |
+
ret = []
|
52 |
+
for k in ("wq", "wk", "wv"):
|
53 |
+
ret.append(fqn.replace("wqkv", k))
|
54 |
+
return ret
|
55 |
+
return [fqn]
|
56 |
+
|
57 |
+
|
58 |
+
def get_shard_dim(fqn: str) -> Optional[int]:
|
59 |
+
if "bias" in fqn:
|
60 |
+
# Some bias params are still sharded
|
61 |
+
if "resblocks" in fqn:
|
62 |
+
for k in ("wq", "wk", "wv", "c_fc"):
|
63 |
+
if k in fqn:
|
64 |
+
return 0
|
65 |
+
return None
|
66 |
+
elif any([x in fqn for x in column_parallel]):
|
67 |
+
return 0
|
68 |
+
elif any([x in fqn for x in row_parallel]):
|
69 |
+
return 1
|
70 |
+
else:
|
71 |
+
return None
|
72 |
+
|
73 |
+
|
74 |
+
def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
75 |
+
qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards]
|
76 |
+
q = torch.cat([qkv[0] for qkv in qkvs], dim=0)
|
77 |
+
k = torch.cat([qkv[1] for qkv in qkvs], dim=0)
|
78 |
+
v = torch.cat([qkv[2] for qkv in qkvs], dim=0)
|
79 |
+
return q, k, v
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class _Assignment:
|
84 |
+
loader_id: int
|
85 |
+
filename: str
|
86 |
+
fqns: tuple[str, ...]
|
87 |
+
shapes: tuple[torch.Size, ...]
|
88 |
+
dtypes: tuple[torch.dtype, ...]
|
89 |
+
|
90 |
+
|
91 |
+
@dataclass
|
92 |
+
class _AssignmentRound:
|
93 |
+
loader_assignments: dict[int, _Assignment] # List of assignments for each loader
|
94 |
+
|
95 |
+
|
96 |
+
class CheckpointConverter:
|
97 |
+
TOTAL_SHARDS = 8
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
process_group: dist.ProcessGroup,
|
102 |
+
path: str,
|
103 |
+
loader_every_n_ranks: int = 8,
|
104 |
+
) -> None:
|
105 |
+
self.path = path
|
106 |
+
self.pg = process_group
|
107 |
+
self.my_rank = dist.get_rank(self.pg)
|
108 |
+
self.loader_every_n_ranks = loader_every_n_ranks
|
109 |
+
self.loader_id = self.my_rank // loader_every_n_ranks
|
110 |
+
self.should_load = (
|
111 |
+
self.my_rank % loader_every_n_ranks == 0
|
112 |
+
and self.loader_id < CheckpointConverter.TOTAL_SHARDS
|
113 |
+
)
|
114 |
+
self.total_loader = CheckpointConverter.TOTAL_SHARDS
|
115 |
+
self.titan_fqn_to_stored_fqn: dict[str, str] = {}
|
116 |
+
self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {}
|
117 |
+
self.total_send_bytes = 0
|
118 |
+
self.total_recv_bytes = 0
|
119 |
+
|
120 |
+
def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
121 |
+
begin = time.time()
|
122 |
+
self._load_metadata()
|
123 |
+
self._create_fqn_mappings(state_dict)
|
124 |
+
rounds = self._get_load_assignments(state_dict)
|
125 |
+
|
126 |
+
for assignments in rounds:
|
127 |
+
loader_assignments = assignments.loader_assignments
|
128 |
+
loaded_state_dict = None
|
129 |
+
# Let each loader to load its own data and move to its GPU.
|
130 |
+
for i in range(self.total_loader):
|
131 |
+
# This loader doesn't have any loading assignment for this round.
|
132 |
+
if i not in loader_assignments:
|
133 |
+
continue
|
134 |
+
# This rank is not the loader
|
135 |
+
if i != self.loader_id or not self.should_load:
|
136 |
+
continue
|
137 |
+
loaded_state_dict = self._load_round(loader_assignments[i])
|
138 |
+
|
139 |
+
results = []
|
140 |
+
for i in range(self.total_loader):
|
141 |
+
if i not in loader_assignments:
|
142 |
+
continue
|
143 |
+
|
144 |
+
if i == self.loader_id and self.should_load:
|
145 |
+
# This rank is the loader. It needs to send the loaded data to
|
146 |
+
# the other ranks.
|
147 |
+
assert loaded_state_dict is not None
|
148 |
+
results.append(
|
149 |
+
self._reshard_send(loader_assignments[i], loaded_state_dict)
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
results.append(
|
153 |
+
self._reshard_receive(loader_assignments[i], state_dict)
|
154 |
+
)
|
155 |
+
|
156 |
+
self._reshard(results, state_dict)
|
157 |
+
|
158 |
+
torch.cuda.synchronize()
|
159 |
+
logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.")
|
160 |
+
logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB")
|
161 |
+
logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB")
|
162 |
+
return state_dict
|
163 |
+
|
164 |
+
def _get_file_path(self, loader_id: int) -> str:
|
165 |
+
return os.path.join(self.path, f"consolidated.0{loader_id}.pth")
|
166 |
+
|
167 |
+
def _load_metadata(self) -> None:
|
168 |
+
if not self.should_load:
|
169 |
+
self.read_dict = {}
|
170 |
+
return
|
171 |
+
self.read_dict = torch.load(
|
172 |
+
self._get_file_path(self.loader_id),
|
173 |
+
mmap=True,
|
174 |
+
weights_only=False,
|
175 |
+
)
|
176 |
+
|
177 |
+
def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None:
|
178 |
+
if not self.read_dict:
|
179 |
+
return
|
180 |
+
|
181 |
+
# Create the mapping from the stored checkpoint keys to TorchTitan keys.
|
182 |
+
for fqn in list(self.read_dict.keys()):
|
183 |
+
titan_fqns = convert_to_titan_fqns(fqn)
|
184 |
+
# We don't know how to process _extra_state
|
185 |
+
if "_extra_state" in fqn:
|
186 |
+
self.read_dict.pop(fqn)
|
187 |
+
continue
|
188 |
+
|
189 |
+
if titan_fqns[0] not in state_dict:
|
190 |
+
for titan_fqn in titan_fqns:
|
191 |
+
assert titan_fqns[0] not in state_dict
|
192 |
+
self.read_dict.pop(fqn)
|
193 |
+
continue
|
194 |
+
self.stored_fqn_to_titan_fqn[fqn] = titan_fqns
|
195 |
+
for titan_fqn in titan_fqns:
|
196 |
+
self.titan_fqn_to_stored_fqn[titan_fqn] = fqn
|
197 |
+
|
198 |
+
assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), (
|
199 |
+
set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()),
|
200 |
+
set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()),
|
201 |
+
)
|
202 |
+
|
203 |
+
def _get_load_assignments(
|
204 |
+
self, state_dict: dict[str, torch.Tensor]
|
205 |
+
) -> list[_AssignmentRound]:
|
206 |
+
if self.my_rank == 0:
|
207 |
+
rounds: list[_AssignmentRound] = []
|
208 |
+
size = 0
|
209 |
+
fqns = []
|
210 |
+
shapes = []
|
211 |
+
dtypes = []
|
212 |
+
|
213 |
+
# All loader must load all the FQNs because the checkpoint is purely TP sharded.
|
214 |
+
all_keys = list(self.read_dict.keys())
|
215 |
+
for fqn in all_keys:
|
216 |
+
fqns.append(fqn)
|
217 |
+
shapes.append(self.read_dict[fqn].shape)
|
218 |
+
dtypes.append(self.read_dict[fqn].dtype)
|
219 |
+
size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size()
|
220 |
+
if size < 1e9 and fqn != all_keys[-1]:
|
221 |
+
continue
|
222 |
+
|
223 |
+
logger.info(f"Adding {fqns} to round {len(rounds)}")
|
224 |
+
round_assignment = _AssignmentRound(loader_assignments={})
|
225 |
+
for loader_id in range(self.total_loader):
|
226 |
+
path = self._get_file_path(loader_id)
|
227 |
+
round_assignment.loader_assignments[loader_id] = _Assignment(
|
228 |
+
filename=path,
|
229 |
+
fqns=tuple(fqns),
|
230 |
+
shapes=tuple(shapes),
|
231 |
+
dtypes=tuple(dtypes),
|
232 |
+
loader_id=loader_id,
|
233 |
+
)
|
234 |
+
rounds.append(round_assignment)
|
235 |
+
size = 0
|
236 |
+
fqns.clear()
|
237 |
+
shapes.clear()
|
238 |
+
dtypes.clear()
|
239 |
+
|
240 |
+
object_list: list[Any] = [
|
241 |
+
rounds,
|
242 |
+
self.titan_fqn_to_stored_fqn,
|
243 |
+
self.stored_fqn_to_titan_fqn,
|
244 |
+
]
|
245 |
+
else:
|
246 |
+
object_list = [None, None, None]
|
247 |
+
|
248 |
+
dist.broadcast_object_list(object_list, src=0, group=self.pg)
|
249 |
+
rounds = object_list[0]
|
250 |
+
self.titan_fqn_to_stored_fqn = object_list[1]
|
251 |
+
self.stored_fqn_to_titan_fqn = object_list[2]
|
252 |
+
return rounds
|
253 |
+
|
254 |
+
def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]:
|
255 |
+
ret = {}
|
256 |
+
assert self.read_dict
|
257 |
+
for fqn in assignment.fqns:
|
258 |
+
ret[fqn] = self.read_dict[fqn].to(device="cuda")
|
259 |
+
return ret
|
260 |
+
|
261 |
+
def _reshard_send(
|
262 |
+
self,
|
263 |
+
assignment: _Assignment,
|
264 |
+
loaded_state_dict: dict[str, torch.Tensor],
|
265 |
+
) -> dict[str, torch.Tensor]:
|
266 |
+
flatten_tensors = [t.flatten() for t in loaded_state_dict.values()]
|
267 |
+
flatten_tensor = torch.concat(flatten_tensors)
|
268 |
+
assert self.loader_id == assignment.loader_id
|
269 |
+
rank = self.loader_id * self.loader_every_n_ranks
|
270 |
+
assert rank == self.my_rank
|
271 |
+
logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}")
|
272 |
+
logger.info(f"Sending {assignment.fqns}")
|
273 |
+
dist.broadcast(flatten_tensor, src=rank, group=self.pg)
|
274 |
+
self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
|
275 |
+
return loaded_state_dict
|
276 |
+
|
277 |
+
def _reshard_receive(
|
278 |
+
self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
|
279 |
+
) -> dict[str, torch.Tensor]:
|
280 |
+
flatten_tensor = torch.empty(
|
281 |
+
sum(math.prod(s) for s in assignment.shapes),
|
282 |
+
dtype=assignment.dtypes[0],
|
283 |
+
device="cuda",
|
284 |
+
)
|
285 |
+
rank = assignment.loader_id * self.loader_every_n_ranks
|
286 |
+
dist.broadcast(flatten_tensor, src=rank, group=self.pg)
|
287 |
+
self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
|
288 |
+
|
289 |
+
ret: dict[str, torch.Tensor] = {}
|
290 |
+
loc = 0
|
291 |
+
for fqn, shape, dtype in zip(
|
292 |
+
assignment.fqns, assignment.shapes, assignment.dtypes
|
293 |
+
):
|
294 |
+
n_ele = math.prod(shape)
|
295 |
+
ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape)
|
296 |
+
loc += n_ele
|
297 |
+
return ret
|
298 |
+
|
299 |
+
def _reshard(
|
300 |
+
self,
|
301 |
+
results: list[dict[str, torch.Tensor]],
|
302 |
+
state_dict: dict[str, torch.Tensor],
|
303 |
+
) -> None:
|
304 |
+
def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]):
|
305 |
+
titan_fqns = self.stored_fqn_to_titan_fqn[fqn]
|
306 |
+
assert len(titan_fqns) == len(full_tensors)
|
307 |
+
for titan_fqn, full_tensor in zip(titan_fqns, full_tensors):
|
308 |
+
dtensor = state_dict[titan_fqn]
|
309 |
+
logger.info(f"{titan_fqn} {full_tensor.sum()}")
|
310 |
+
assert isinstance(dtensor, DTensor)
|
311 |
+
shape, offset = compute_local_shape_and_global_offset(
|
312 |
+
full_tensor.shape, dtensor.device_mesh, dtensor.placements
|
313 |
+
)
|
314 |
+
slices = [
|
315 |
+
slice(cur_offset, cur_offset + cur_shape)
|
316 |
+
for cur_shape, cur_offset in zip(shape, offset)
|
317 |
+
]
|
318 |
+
logger.info(
|
319 |
+
f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} "
|
320 |
+
f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} "
|
321 |
+
f"{dtensor.placements=} {dtensor.device_mesh=} "
|
322 |
+
)
|
323 |
+
dtensor.to_local().copy_(full_tensor[slices])
|
324 |
+
|
325 |
+
def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
326 |
+
if "wqkv" in fqn:
|
327 |
+
if "layer_norm" in fqn:
|
328 |
+
return (shards[0],)
|
329 |
+
return split_fused_qkv(shards)
|
330 |
+
|
331 |
+
shard_dim = get_shard_dim(fqn)
|
332 |
+
if shard_dim is None:
|
333 |
+
return (shards[0],)
|
334 |
+
return (torch.cat(shards, dim=shard_dim),)
|
335 |
+
|
336 |
+
fqns = list(results[0].keys())
|
337 |
+
for result in results:
|
338 |
+
assert list(result.keys()) == fqns
|
339 |
+
|
340 |
+
for fqn in fqns:
|
341 |
+
full_tensors = _concat_shards(fqn, [result[fqn] for result in results])
|
342 |
+
_inplace_copy(fqn, full_tensors)
|
343 |
+
|
344 |
+
|
345 |
+
def _create_verified_state_dict(
|
346 |
+
pg: dist.ProcessGroup, mesh: DeviceMesh
|
347 |
+
) -> dict[str, torch.Tensor]:
|
348 |
+
placements = [Shard(0)]
|
349 |
+
state_dict = {
|
350 |
+
"tok_embeddings.weight": torch.rand(
|
351 |
+
25256 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
352 |
+
),
|
353 |
+
"layers.47.attention.wqkv.layer_norm_weight": torch.rand(
|
354 |
+
5120, device="cuda", dtype=torch.bfloat16
|
355 |
+
),
|
356 |
+
"layers.47.attention.wq.weight": torch.rand(
|
357 |
+
640 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
358 |
+
),
|
359 |
+
"layers.47.attention.wk.weight": torch.rand(
|
360 |
+
128 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
361 |
+
),
|
362 |
+
"layers.47.attention.wv.weight": torch.rand(
|
363 |
+
128 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
364 |
+
),
|
365 |
+
"layers.47.attention.wo.weight": torch.rand(
|
366 |
+
5120, 640 * 8, device="cuda", dtype=torch.bfloat16
|
367 |
+
),
|
368 |
+
# "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16),
|
369 |
+
# "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
|
370 |
+
# "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
|
371 |
+
"layers.47.feed_forward.w_in_shared_FD.weight": torch.rand(
|
372 |
+
1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
373 |
+
),
|
374 |
+
"layers.47.feed_forward.w_out_shared_DF.weight": torch.rand(
|
375 |
+
5120, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
376 |
+
),
|
377 |
+
"layers.47.feed_forward.w_swiglu_FD.weight": torch.rand(
|
378 |
+
1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
379 |
+
),
|
380 |
+
"layers.47.feed_forward.norm.weight": torch.rand(
|
381 |
+
5120, device="cuda", dtype=torch.bfloat16
|
382 |
+
),
|
383 |
+
"layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand(
|
384 |
+
655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
385 |
+
),
|
386 |
+
"layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand(
|
387 |
+
131072 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
388 |
+
),
|
389 |
+
"layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand(
|
390 |
+
655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
391 |
+
),
|
392 |
+
}
|
393 |
+
return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()}
|
394 |
+
|
395 |
+
|
396 |
+
def _verify_state_dict(
|
397 |
+
state_dict: dict[str, torch.Tensor], path: str, rank: int
|
398 |
+
) -> None:
|
399 |
+
stored_state_dicts = [
|
400 |
+
torch.load(
|
401 |
+
os.path.join(path, f"consolidated.0{i}.pth"),
|
402 |
+
map_location="cpu",
|
403 |
+
weights_only=False,
|
404 |
+
mmap=True,
|
405 |
+
)
|
406 |
+
for i in range(8)
|
407 |
+
]
|
408 |
+
|
409 |
+
def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None:
|
410 |
+
logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ")
|
411 |
+
shards = [stored_state_dicts[i][fqn] for i in range(8)]
|
412 |
+
full_tensor = dtensor.full_tensor()
|
413 |
+
logger.info(f"Gather {fqn} {full_tensor.shape} completely.")
|
414 |
+
|
415 |
+
if rank > 0:
|
416 |
+
return
|
417 |
+
|
418 |
+
if len(shards[0].shape) == 1:
|
419 |
+
assert full_tensor.shape == shards[0].shape, fqn
|
420 |
+
assert torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn
|
421 |
+
return
|
422 |
+
elif shards[0].shape[0] == full_tensor.shape[0]:
|
423 |
+
concat_shards = torch.cat(shards, dim=1)
|
424 |
+
logger.info(f"Load {fqn} completely.")
|
425 |
+
elif shards[0].shape[1] == full_tensor.shape[1]:
|
426 |
+
concat_shards = torch.cat(shards, dim=0)
|
427 |
+
logger.info(f"Load {fqn} completely.")
|
428 |
+
|
429 |
+
concat_shards = concat_shards.to(device="cuda")
|
430 |
+
logger.info(f"Move to GPU {fqn} completely.")
|
431 |
+
|
432 |
+
assert concat_shards.shape == full_tensor.shape, fqn
|
433 |
+
assert concat_shards.dtype == full_tensor.dtype, fqn
|
434 |
+
assert concat_shards.device == full_tensor.device, fqn
|
435 |
+
assert torch.allclose(concat_shards, full_tensor), fqn
|
436 |
+
|
437 |
+
for k, v in state_dict.items():
|
438 |
+
if "wq" in k and "wqkv" not in k:
|
439 |
+
pass
|
440 |
+
elif "wk" in k:
|
441 |
+
pass
|
442 |
+
elif "wv" in k:
|
443 |
+
pass
|
444 |
+
else:
|
445 |
+
assert v is not None, k
|
446 |
+
read_and_verify_tensor(k, v)
|
447 |
+
|
448 |
+
|
449 |
+
if __name__ == "__main__":
|
450 |
+
init_logger()
|
451 |
+
config = JobConfig()
|
452 |
+
config.parser.add_argument(
|
453 |
+
"--checkpoint.convert_path",
|
454 |
+
type=str,
|
455 |
+
default="",
|
456 |
+
help="""Specify the path of the target checkpoint to convert.""",
|
457 |
+
)
|
458 |
+
config.parser.add_argument(
|
459 |
+
"--checkpoint.convert_load_every_n_ranks",
|
460 |
+
type=int,
|
461 |
+
default=8,
|
462 |
+
help="""
|
463 |
+
Specify the interval at which ranks are assigned to load checkpoints.
|
464 |
+
|
465 |
+
For example, if this number is 4, then ranks 0, 4, 8, ... will load the
|
466 |
+
checkpoint. Each loader is responsible for loading one file. If there
|
467 |
+
are more loaders than files, only the first few loaders will be assigned
|
468 |
+
to load the checkpoint. The default value is 8.
|
469 |
+
""",
|
470 |
+
)
|
471 |
+
config.parser.add_argument(
|
472 |
+
"--checkpoint.fake_model",
|
473 |
+
action="store_true",
|
474 |
+
help="""If true, the model will be fake.""",
|
475 |
+
)
|
476 |
+
config.parse_args()
|
477 |
+
assert config.checkpoint.convert_path != ""
|
478 |
+
|
479 |
+
trainer: Optional[Trainer] = None
|
480 |
+
|
481 |
+
try:
|
482 |
+
trainer = Trainer(config)
|
483 |
+
if os.path.exists(trainer.checkpointer.folder):
|
484 |
+
raise RuntimeError(
|
485 |
+
"The checkpoint folder already exists. Abort to avoid overwriting "
|
486 |
+
f"the checkpoint. {trainer.checkpointer.folder=}"
|
487 |
+
)
|
488 |
+
if config.checkpoint.fake_model:
|
489 |
+
state_dict = _create_verified_state_dict(
|
490 |
+
trainer.world_mesh.get_group(), trainer.world_mesh
|
491 |
+
)
|
492 |
+
else:
|
493 |
+
state_dict = trainer.checkpointer.states[MODEL].state_dict()
|
494 |
+
|
495 |
+
size = 0
|
496 |
+
for v in state_dict.values():
|
497 |
+
size += v.numel() * v.element_size()
|
498 |
+
logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
|
499 |
+
|
500 |
+
# Do not support PP yet, we will need to iterate over the PP dimension and
|
501 |
+
# extract the corresponding state_dict and device_mesh.
|
502 |
+
if "freq_cis" in state_dict:
|
503 |
+
state_dict.pop("freqs_cis")
|
504 |
+
|
505 |
+
state_dict = CheckpointConverter(
|
506 |
+
process_group=trainer.world_mesh.get_group(),
|
507 |
+
path=config.checkpoint.convert_path,
|
508 |
+
loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks,
|
509 |
+
).convert(state_dict)
|
510 |
+
|
511 |
+
class DummyModel:
|
512 |
+
def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
|
513 |
+
self._state_dict = state_dict
|
514 |
+
|
515 |
+
def state_dict(self) -> dict[str, torch.Tensor]:
|
516 |
+
return self._state_dict
|
517 |
+
|
518 |
+
if config.checkpoint.fake_model:
|
519 |
+
begin = time.time()
|
520 |
+
_verify_state_dict(
|
521 |
+
state_dict,
|
522 |
+
config.checkpoint.convert_path,
|
523 |
+
trainer.world_mesh.get_rank(),
|
524 |
+
)
|
525 |
+
dist.barrier()
|
526 |
+
logger.info(f"Verifies state_dict {time.time() - begin}.")
|
527 |
+
else:
|
528 |
+
# TODO: remove this workaround once the freqs_cis handling is cleaned up.
|
529 |
+
state_dict["freqs_cis"] = None
|
530 |
+
trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
|
531 |
+
trainer.checkpointer.model_weights_only = True
|
532 |
+
trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
|
533 |
+
trainer.checkpointer.save(curr_step=0, force=True)
|
534 |
+
time.sleep(2)
|
535 |
+
finally:
|
536 |
+
pass
|
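The send/receive pair above packs each round's tensors into a single flat buffer so that one `dist.broadcast` per round suffices. Below is a minimal, self-contained sketch of that packing and unpacking (no process group involved; tensor names and shapes are hypothetical). Note that the receive path allocates the buffer with `assignment.dtypes[0]`, so it implicitly assumes every tensor in a round shares one dtype.

```python
import math

import torch

# Hypothetical per-round payload, standing in for `loaded_state_dict`.
tensors = {
    "wq.weight": torch.randn(4, 8),
    "wk.weight": torch.randn(2, 8),
    "norm.weight": torch.randn(8),
}

# Sender side (_reshard_send): flatten and concatenate into one buffer,
# so a round needs a single broadcast instead of one per tensor.
flat = torch.cat([t.flatten() for t in tensors.values()])

# Receiver side (_reshard_receive): slice the buffer back by shape.
restored, loc = {}, 0
for fqn, t in tensors.items():
    n_ele = math.prod(t.shape)
    restored[fqn] = flat[loc : loc + n_ele].view(t.shape)
    loc += n_ele

assert all(torch.equal(tensors[k], restored[k]) for k in tensors)
```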
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
ADDED
@@ -0,0 +1,65 @@
1 |
+
# TODO: this toml config is still under development
|
2 |
+
|
3 |
+
[job]
|
4 |
+
dump_folder = "./outputs"
|
5 |
+
description = "Llama 4 Maverick 17Bx128E training"
|
6 |
+
|
7 |
+
[profiling]
|
8 |
+
enable_profiling = false
|
9 |
+
save_traces_folder = "profile_trace"
|
10 |
+
profile_freq = 100
|
11 |
+
|
12 |
+
[metrics]
|
13 |
+
log_freq = 10
|
14 |
+
enable_tensorboard = false
|
15 |
+
save_tb_folder = "tb"
|
16 |
+
|
17 |
+
[model]
|
18 |
+
name = "llama4"
|
19 |
+
flavor = "17bx128e"
|
20 |
+
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
|
21 |
+
tokenizer_path = "./assets/tokenizer/tokenizer.model"
|
22 |
+
# converters = "float8"
|
23 |
+
|
24 |
+
[optimizer]
|
25 |
+
name = "AdamW"
|
26 |
+
lr = 4e-3
|
27 |
+
eps = 1e-15
|
28 |
+
|
29 |
+
[lr_scheduler]
|
30 |
+
warmup_steps = 600
|
31 |
+
lr_min = 0.1
|
32 |
+
|
33 |
+
[training]
|
34 |
+
batch_size = 1
|
35 |
+
seq_len = 8192
|
36 |
+
max_norm = 1.0 # grad norm clipping
|
37 |
+
steps = 3000
|
38 |
+
compile = false
|
39 |
+
dataset = "c4"
|
40 |
+
|
41 |
+
[parallelism]
|
42 |
+
data_parallel_replicate_degree = 1
|
43 |
+
data_parallel_shard_degree = -1
|
44 |
+
tensor_parallel_degree = 8
|
45 |
+
enable_async_tensor_parallel = false
|
46 |
+
pipeline_parallel_degree = 4
|
47 |
+
# pipeline_parallel_schedule = "interleaved1f1b"
|
48 |
+
# pipeline_parallel_microbatches = 2
|
49 |
+
context_parallel_degree = 1
|
50 |
+
|
51 |
+
[checkpoint]
|
52 |
+
enable_checkpoint = false
|
53 |
+
folder = "checkpoint"
|
54 |
+
interval = 500
|
55 |
+
model_weights_only = false
|
56 |
+
export_dtype = "float32"
|
57 |
+
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
|
58 |
+
|
59 |
+
[activation_checkpoint]
|
60 |
+
mode = 'full' # ['none', 'selective', 'full']
|
61 |
+
|
62 |
+
[float8]
|
63 |
+
enable_fsdp_float8_all_gather = false
|
64 |
+
precompute_float8_dynamic_scale_for_fsdp = false
|
65 |
+
filter_fqns = "output,router.gate"
|
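For readers checking the resource math of the config above: `data_parallel_shard_degree = -1` is conventionally interpreted as "infer the FSDP sharding degree from whatever ranks remain". A rough sketch of that bookkeeping, under the assumption (not stated in this file) that the product of all parallel degrees must equal the world size:

```python
# Degrees taken from the TOML above; the world size is hypothetical.
world_size = 256
dp_replicate, tp, pp, cp = 1, 8, 4, 1

# -1 for data_parallel_shard_degree: infer it from the leftover ranks.
dp_shard = world_size // (dp_replicate * tp * pp * cp)
assert dp_replicate * dp_shard * tp * pp * cp == world_size
print(dp_shard)  # 8 ranks per FSDP sharding group in this example
```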
torchtitan/experiments/multimodal/mm_dataset.py
ADDED
@@ -0,0 +1,268 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from datasets import Dataset, load_dataset
|
13 |
+
from datasets.distributed import split_dataset_by_node
|
14 |
+
|
15 |
+
from mm_collator import MultiModalCollator
|
16 |
+
from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer
|
17 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
18 |
+
from torch.utils.data import IterableDataset
|
19 |
+
from transform import CLIPTransform
|
20 |
+
from utils import load_image
|
21 |
+
|
22 |
+
from torchtitan.components.dataloader import ParallelAwareDataloader
|
23 |
+
from torchtitan.config_manager import JobConfig
|
24 |
+
from torchtitan.tools.logging import logger
|
25 |
+
|
26 |
+
|
27 |
+
def _load_obelics_dataset(dataset_path: str):
|
28 |
+
"""Load C4 dataset with default configuration."""
|
29 |
+
return load_dataset(dataset_path, split="train", streaming=True)
|
30 |
+
|
31 |
+
|
32 |
+
def _process_obelics_sample(
|
33 |
+
sample: dict[str, Any], image_token: str = "<|image|>"
|
34 |
+
) -> Dict[str, List[Union[str, "PIL.Image.Image"]]]:
|
35 |
+
"""
|
36 |
+
This function formats samples from the OBELICS dataset
|
37 |
+
Returns:
|
38 |
+
Dict[str, Any]: The transformed sample with the following fields:
|
39 |
+
- images: List[PIL.Image.Image] with the loaded images
|
40 |
+
- text: str with the text of the sample ready to be tokenized including the image tokens
|
41 |
+
Example:
|
42 |
+
>>> formatted_sample = format_obelics(sample, image_token="<|image|>")
|
43 |
+
>>> print(formatted_sample["text"])
|
44 |
+
... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :("
|
45 |
+
"""
|
46 |
+
sample_images = [image for image in sample["images"] if image is not None]
|
47 |
+
sample_text = [
|
48 |
+
text if text is not None else image_token for text in sample["texts"]
|
49 |
+
]
|
50 |
+
return {
|
51 |
+
"images": [load_image(image) for image in sample_images],
|
52 |
+
"text": "".join(map(str, sample_text)),
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class DatasetConfig:
|
58 |
+
path: str
|
59 |
+
loader: Callable
|
60 |
+
sample_processor: Callable
|
61 |
+
|
62 |
+
|
63 |
+
# Add your dataset here - more information at docs/datasets.md
|
64 |
+
MM_DATASETS = {
|
65 |
+
"obelics": DatasetConfig(
|
66 |
+
path="HuggingFaceM4/OBELICS",
|
67 |
+
loader=_load_obelics_dataset,
|
68 |
+
sample_processor=_process_obelics_sample,
|
69 |
+
),
|
70 |
+
}
|
71 |
+
|
72 |
+
|
73 |
+
def _validate_mm_dataset(
|
74 |
+
dataset_name: str, dataset_path: str = None
|
75 |
+
) -> tuple[str, Callable, Callable]:
|
76 |
+
"""Validate dataset name and path."""
|
77 |
+
if dataset_name not in MM_DATASETS:
|
78 |
+
raise ValueError(
|
79 |
+
f"Dataset {dataset_name} is not supported. "
|
80 |
+
f"Supported datasets are: {list(MM_DATASETS.keys())}"
|
81 |
+
)
|
82 |
+
|
83 |
+
config = MM_DATASETS[dataset_name]
|
84 |
+
path = dataset_path or config.path
|
85 |
+
logger.info(f"Preparing {dataset_name} dataset from {path}")
|
86 |
+
return path, config.loader, config.sample_processor
|
87 |
+
|
88 |
+
|
89 |
+
class MultiModalDataset(IterableDataset, Stateful):
|
90 |
+
"""PyTorch MultiModal Dataset.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
dataset_name (str): name of the dataset to load
|
94 |
+
tokenizer (Tokenizer):
|
95 |
+
Tokenizer used to encode data. The tokenizer must implement an `encode` and a `decode` method.
|
96 |
+
world_size (int): number of data parallel processes participating in training
|
97 |
+
rank (int): rank of the current data parallel process
|
98 |
+
infinite (bool): whether to loop infinitely over the dataset
|
99 |
+
|
100 |
+
We currently ONLY support the OBELICS dataset
|
101 |
+
|
102 |
+
Example use:
|
103 |
+
>>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
|
104 |
+
>>> for batch in Dataloader(ds, batch_size=8):
|
105 |
+
print(f"Batch size: {len(batch)}")
|
106 |
+
Batch size: 8
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
dataset_name: str,
|
112 |
+
dataset_path: Optional[str],
|
113 |
+
tokenizer: Tokenizer,
|
114 |
+
image_token: str = "<|image|>",
|
115 |
+
tile_size: int = 448,
|
116 |
+
max_num_tiles: int = 4,
|
117 |
+
seq_len: int = 2048,
|
118 |
+
dp_rank: int = 0,
|
119 |
+
dp_world_size: int = 1,
|
120 |
+
infinite: bool = False,
|
121 |
+
) -> None:
|
122 |
+
# Force lowercase for consistent comparison
|
123 |
+
dataset_name = dataset_name.lower()
|
124 |
+
|
125 |
+
path, dataset_loader, sample_processor = _validate_mm_dataset(
|
126 |
+
dataset_name, dataset_path
|
127 |
+
)
|
128 |
+
ds = dataset_loader(path)
|
129 |
+
|
130 |
+
# TODO: support shuffling
|
131 |
+
self.dataset_name = dataset_name
|
132 |
+
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
|
133 |
+
self._tokenizer = tokenizer
|
134 |
+
self.seq_len = seq_len
|
135 |
+
self.infinite = infinite
|
136 |
+
self._sample_processor = sample_processor
|
137 |
+
self.image_token = (
|
138 |
+
image_token # TODO(tj.solergibert) Add `image_token` to JobConfig
|
139 |
+
)
|
140 |
+
# TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
|
141 |
+
self.transform_image = CLIPTransform(
|
142 |
+
image_mean=(
|
143 |
+
0.48145466,
|
144 |
+
0.4578275,
|
145 |
+
0.40821073,
|
146 |
+
), # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?,
|
147 |
+
image_std=(0.26862954, 0.26130258, 0.27577711),
|
148 |
+
tile_size=tile_size,
|
149 |
+
possible_resolutions=None,
|
150 |
+
max_num_tiles=max_num_tiles,
|
151 |
+
resample="bilinear",
|
152 |
+
resize_to_max_canvas=False,
|
153 |
+
)
|
154 |
+
|
155 |
+
# variables for checkpointing
|
156 |
+
self._sample_idx = 0
|
157 |
+
|
158 |
+
def __iter__(self):
|
159 |
+
|
160 |
+
while True:
|
161 |
+
for sample in self._get_data_iter():
|
162 |
+
try:
|
163 |
+
sample = self._sample_processor(
|
164 |
+
sample, image_token=self.image_token
|
165 |
+
)
|
166 |
+
except Exception:
|
167 |
+
continue
|
168 |
+
self._sample_idx += 1
|
169 |
+
|
170 |
+
# CLIP Transform
|
171 |
+
encoder_input = {"images": [], "aspect_ratio": []}
|
172 |
+
for image in sample["images"]:
|
173 |
+
out = self.transform_image(image)
|
174 |
+
encoder_input["images"].append(out["image"])
|
175 |
+
encoder_input["aspect_ratio"].append(out["aspect_ratio"])
|
176 |
+
sample["encoder_input"] = encoder_input
|
177 |
+
|
178 |
+
# Tokenize
|
179 |
+
tokens = self._tokenizer.encode(
|
180 |
+
sample["text"],
|
181 |
+
bos=True,
|
182 |
+
eos=True,
|
183 |
+
allowed_special=set(["<|image|>"]),
|
184 |
+
)
|
185 |
+
sample["input_ids"] = torch.LongTensor(tokens[:-1])
|
186 |
+
sample["labels"] = torch.LongTensor(tokens[1:])
|
187 |
+
# Mask BOS, EOS & image tokens from the loss
|
188 |
+
sample["labels"] = torch.where(
|
189 |
+
torch.isin(
|
190 |
+
sample["labels"],
|
191 |
+
torch.LongTensor(
|
192 |
+
[
|
193 |
+
self._tokenizer.bos_id,
|
194 |
+
self._tokenizer.eos_id,
|
195 |
+
self._tokenizer.image_id,
|
196 |
+
]
|
197 |
+
),
|
198 |
+
),
|
199 |
+
IGNORE_INDEX,
|
200 |
+
sample["labels"],
|
201 |
+
)
|
202 |
+
# Truncate
|
203 |
+
sample["input_ids"], sample["labels"] = (
|
204 |
+
sample["input_ids"][: self.seq_len],
|
205 |
+
sample["labels"][: self.seq_len],
|
206 |
+
)
|
207 |
+
yield sample
|
208 |
+
|
209 |
+
if not self.infinite:
|
210 |
+
logger.warning(f"Dataset {self.dataset_name} has run out of data")
|
211 |
+
break
|
212 |
+
else:
|
213 |
+
# Reset offset for the next iteration
|
214 |
+
self._sample_idx = 0
|
215 |
+
logger.warning(f"Dataset {self.dataset_name} is being re-looped")
|
216 |
+
|
217 |
+
def _get_data_iter(self):
|
218 |
+
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
|
219 |
+
return iter([])
|
220 |
+
|
221 |
+
it = iter(self._data)
|
222 |
+
for _ in range(self._sample_idx):
|
223 |
+
next(it)
|
224 |
+
return it
|
225 |
+
|
226 |
+
def load_state_dict(self, state_dict):
|
227 |
+
self._sample_idx = state_dict["sample_idx"]
|
228 |
+
|
229 |
+
def state_dict(self):
|
230 |
+
return {"sample_idx": self._sample_idx}
|
231 |
+
|
232 |
+
|
233 |
+
def build_mm_dataloader(
|
234 |
+
dp_world_size: int,
|
235 |
+
dp_rank: int,
|
236 |
+
tokenizer: Tokenizer,
|
237 |
+
job_config: JobConfig,
|
238 |
+
infinite: bool = True,
|
239 |
+
) -> ParallelAwareDataloader:
|
240 |
+
"""Build a data loader for HuggingFace datasets."""
|
241 |
+
dataset_name = job_config.training.dataset
|
242 |
+
dataset_path = job_config.training.dataset_path
|
243 |
+
batch_size = job_config.training.batch_size
|
244 |
+
seq_len = job_config.training.seq_len
|
245 |
+
pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
|
246 |
+
padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig
|
247 |
+
|
248 |
+
hf_ds = MultiModalDataset(
|
249 |
+
dataset_name=dataset_name,
|
250 |
+
dataset_path=dataset_path,
|
251 |
+
tokenizer=tokenizer,
|
252 |
+
seq_len=seq_len,
|
253 |
+
dp_rank=dp_rank,
|
254 |
+
dp_world_size=dp_world_size,
|
255 |
+
infinite=infinite,
|
256 |
+
)
|
257 |
+
|
258 |
+
collate_fn = MultiModalCollator(
|
259 |
+
padding_idx=padding_idx, pad_max_tiles=pad_max_tiles
|
260 |
+
)
|
261 |
+
|
262 |
+
return ParallelAwareDataloader(
|
263 |
+
dataset=hf_ds,
|
264 |
+
dp_rank=dp_rank,
|
265 |
+
dp_world_size=dp_world_size,
|
266 |
+
batch_size=batch_size,
|
267 |
+
collate_fn=collate_fn,
|
268 |
+
)
|
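The label construction in `MultiModalDataset.__iter__` above shifts the token sequence by one position and then masks BOS, EOS, and image tokens out of the loss. A small self-contained sketch with made-up token ids (the real ids come from the tokenizer):

```python
import torch

IGNORE_INDEX = -100                  # value ignored by the loss
BOS_ID, EOS_ID, IMAGE_ID = 0, 1, 9   # hypothetical special-token ids

tokens = [BOS_ID, 5, 6, IMAGE_ID, 7, EOS_ID]

# Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
input_ids = torch.LongTensor(tokens[:-1])
labels = torch.LongTensor(tokens[1:])

# Mask special tokens so they do not contribute to the loss.
special = torch.LongTensor([BOS_ID, EOS_ID, IMAGE_ID])
labels = torch.where(torch.isin(labels, special), IGNORE_INDEX, labels)

print(input_ids.tolist())  # [0, 5, 6, 9, 7]
print(labels.tolist())     # [5, 6, -100, 7, -100]
```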
torchtitan/experiments/multimodal/requirements.txt
ADDED
@@ -0,0 +1 @@
1 |
+
torchvision
|
torchtitan/experiments/multimodal/tests/test_utils.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
def fixed_init_tensor(
|
16 |
+
shape: torch.Size,
|
17 |
+
min_val: Union[float, int] = 0.0,
|
18 |
+
max_val: Union[float, int] = 1.0,
|
19 |
+
nonlinear: bool = False,
|
20 |
+
dtype: torch.dtype = torch.float,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Utility for generating deterministic tensors of a given shape. In general stuff
|
24 |
+
like torch.ones, torch.eye, etc can result in trivial outputs. This utility
|
25 |
+
generates a range tensor [min_val, max_val) of a specified dtype, applies
|
26 |
+
a sine function if nonlinear=True, then reshapes to the appropriate shape.
|
27 |
+
"""
|
28 |
+
n_elements = math.prod(shape)
|
29 |
+
step_size = (max_val - min_val) / n_elements
|
30 |
+
x = torch.arange(min_val, max_val, step_size, dtype=dtype)
|
31 |
+
x = x.reshape(shape)
|
32 |
+
if nonlinear:
|
33 |
+
return torch.sin(x)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad
|
38 |
+
def fixed_init_model(
|
39 |
+
model: nn.Module,
|
40 |
+
min_val: Union[float, int] = 0.0,
|
41 |
+
max_val: Union[float, int] = 1.0,
|
42 |
+
nonlinear: bool = False,
|
43 |
+
dtype: Optional[torch.dtype] = None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
This utility initializes all parameters of a model deterministically using the
|
47 |
+
function fixed_init_tensor above. See that docstring for details of each parameter.
|
48 |
+
"""
|
49 |
+
for _, param in model.named_parameters():
|
50 |
+
param.copy_(
|
51 |
+
fixed_init_tensor(
|
52 |
+
param.shape,
|
53 |
+
min_val=min_val,
|
54 |
+
max_val=max_val,
|
55 |
+
nonlinear=nonlinear,
|
56 |
+
dtype=param.dtype if dtype is None else dtype,
|
57 |
+
)
|
58 |
+
)
|
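A quick usage sketch for the helpers above: two modules initialized this way get identical, non-trivial weights, which is what makes numerics comparisons across parallelism configs meaningful. The import path is assumed from the file location above.

```python
import torch
from torch import nn

# Path assumed from the file location above.
from torchtitan.experiments.multimodal.tests.test_utils import (
    fixed_init_model,
    fixed_init_tensor,
)

# Deterministic range-based values reshaped to the requested shape.
print(fixed_init_tensor(torch.Size([2, 3]), min_val=0.0, max_val=1.0))

# Two identical architectures end up with identical parameters.
m1, m2 = nn.Linear(4, 2), nn.Linear(4, 2)
fixed_init_model(m1, nonlinear=True)
fixed_init_model(m2, nonlinear=True)
x = torch.randn(1, 4)
assert torch.allclose(m1(x), m2(x))
```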
torchtitan/experiments/multimodal/tokenizer/tiktoken.py
ADDED
@@ -0,0 +1,232 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
8 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
9 |
+
|
10 |
+
import os
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import (
|
13 |
+
AbstractSet,
|
14 |
+
Any,
|
15 |
+
cast,
|
16 |
+
Collection,
|
17 |
+
Dict,
|
18 |
+
Iterator,
|
19 |
+
List,
|
20 |
+
Literal,
|
21 |
+
Mapping,
|
22 |
+
Optional,
|
23 |
+
Sequence,
|
24 |
+
Union,
|
25 |
+
)
|
26 |
+
|
27 |
+
import tiktoken
|
28 |
+
import torch
|
29 |
+
from tiktoken.load import load_tiktoken_bpe
|
30 |
+
|
31 |
+
from torchtitan.components.tokenizer import Tokenizer
|
32 |
+
from torchtitan.config_manager import JobConfig
|
33 |
+
from torchtitan.tools.logging import logger
|
34 |
+
|
35 |
+
IMAGE_TOKEN_ID = 128256
|
36 |
+
IGNORE_INDEX = -100
|
37 |
+
|
38 |
+
|
39 |
+
class TikTokenizer(Tokenizer):
|
40 |
+
"""
|
41 |
+
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
model_path (str): The path to the Tiktoken model file.
|
45 |
+
"""
|
46 |
+
|
47 |
+
special_tokens: Dict[str, int]
|
48 |
+
|
49 |
+
num_reserved_special_tokens = 256
|
50 |
+
|
51 |
+
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950
|
52 |
+
|
53 |
+
def __init__(self, model_path: str):
|
54 |
+
super().__init__(model_path)
|
55 |
+
assert os.path.isfile(model_path), model_path
|
56 |
+
|
57 |
+
mergeable_ranks = load_tiktoken_bpe(model_path)
|
58 |
+
num_base_tokens = len(mergeable_ranks)
|
59 |
+
special_tokens = [
|
60 |
+
"<|begin_of_text|>",
|
61 |
+
"<|end_of_text|>",
|
62 |
+
"<|reserved_special_token_0|>",
|
63 |
+
"<|reserved_special_token_1|>",
|
64 |
+
"<|reserved_special_token_2|>",
|
65 |
+
"<|reserved_special_token_3|>",
|
66 |
+
"<|start_header_id|>",
|
67 |
+
"<|end_header_id|>",
|
68 |
+
"<|reserved_special_token_4|>",
|
69 |
+
"<|eot_id|>", # end of turn
|
70 |
+
] + [
|
71 |
+
f"<|reserved_special_token_{i}|>"
|
72 |
+
for i in range(5, self.num_reserved_special_tokens - 5)
|
73 |
+
]
|
74 |
+
self.special_tokens = {
|
75 |
+
token: num_base_tokens + i for i, token in enumerate(special_tokens)
|
76 |
+
}
|
77 |
+
self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
|
78 |
+
self.model = tiktoken.Encoding(
|
79 |
+
name=Path(model_path).name,
|
80 |
+
pat_str=self.pat_str,
|
81 |
+
mergeable_ranks=mergeable_ranks,
|
82 |
+
special_tokens=self.special_tokens,
|
83 |
+
)
|
84 |
+
|
85 |
+
self._n_words: int = self.model.n_vocab
|
86 |
+
# BOS / EOS token IDs
|
87 |
+
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
88 |
+
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
89 |
+
self.pad_id: int = -1
|
90 |
+
self.image_id = IMAGE_TOKEN_ID
|
91 |
+
self.stop_tokens = {
|
92 |
+
self.special_tokens["<|end_of_text|>"],
|
93 |
+
self.special_tokens["<|eot_id|>"],
|
94 |
+
}
|
95 |
+
logger.info(
|
96 |
+
f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
|
97 |
+
)
|
98 |
+
|
99 |
+
def encode(
|
100 |
+
self,
|
101 |
+
s: str,
|
102 |
+
*,
|
103 |
+
bos: bool,
|
104 |
+
eos: bool,
|
105 |
+
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
106 |
+
disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
|
107 |
+
) -> List[int]:
|
108 |
+
"""
|
109 |
+
Encodes a string into a list of token IDs.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
s (str): The input string to be encoded.
|
113 |
+
bos (bool): Whether to prepend the beginning-of-sequence token.
|
114 |
+
eos (bool): Whether to append the end-of-sequence token.
|
115 |
+
allowed_tokens ("all"|set[str]): allowed special tokens in string
|
116 |
+
disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
list[int]: A list of token IDs.
|
120 |
+
|
121 |
+
By default, setting disallowed_special=() encodes a string by ignoring
|
122 |
+
special tokens. Specifically:
|
123 |
+
- Setting `disallowed_special` to () will cause all text corresponding
|
124 |
+
to special tokens to be encoded as natural text (instead of raising
|
125 |
+
an error).
|
126 |
+
- Setting `allowed_special` to "all" will treat all text corresponding
|
127 |
+
to special tokens to be encoded as special tokens.
|
128 |
+
"""
|
129 |
+
assert type(s) is str
|
130 |
+
allowed_special = allowed_special or set()
|
131 |
+
disallowed_special = disallowed_special or ()
|
132 |
+
|
133 |
+
# The tiktoken tokenizer can handle <=400k chars without
|
134 |
+
# pyo3_runtime.PanicException.
|
135 |
+
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
136 |
+
|
137 |
+
# https://github.com/openai/tiktoken/issues/195
|
138 |
+
# Here we iterate over subsequences and split if we exceed the limit
|
139 |
+
# of max consecutive non-whitespace or whitespace characters.
|
140 |
+
MAX_NO_WHITESPACES_CHARS = 25_000
|
141 |
+
|
142 |
+
substrs = (
|
143 |
+
substr
|
144 |
+
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
145 |
+
for substr in self._split_whitespaces_or_nonwhitespaces(
|
146 |
+
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
147 |
+
)
|
148 |
+
)
|
149 |
+
t: List[int] = []
|
150 |
+
for substr in substrs:
|
151 |
+
t.extend(
|
152 |
+
self.model.encode(
|
153 |
+
substr,
|
154 |
+
allowed_special=allowed_special,
|
155 |
+
disallowed_special=disallowed_special,
|
156 |
+
)
|
157 |
+
)
|
158 |
+
if bos:
|
159 |
+
t.insert(0, self.bos_id)
|
160 |
+
if eos:
|
161 |
+
t.append(self.eos_id)
|
162 |
+
return t
|
163 |
+
|
164 |
+
def decode(self, t: Sequence[int]) -> str:
|
165 |
+
"""
|
166 |
+
Decodes a list of token IDs into a string.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
t (List[int]): The list of token IDs to be decoded.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: The decoded string.
|
173 |
+
"""
|
174 |
+
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
175 |
+
return self.model.decode(cast(List[int], t))
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def _split_whitespaces_or_nonwhitespaces(
|
179 |
+
s: str, max_consecutive_slice_len: int
|
180 |
+
) -> Iterator[str]:
|
181 |
+
"""
|
182 |
+
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
183 |
+
consecutive whitespaces or consecutive non-whitespaces.
|
184 |
+
"""
|
185 |
+
current_slice_len = 0
|
186 |
+
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
187 |
+
slice_start = 0
|
188 |
+
|
189 |
+
for i in range(len(s)):
|
190 |
+
is_now_space = s[i].isspace()
|
191 |
+
|
192 |
+
if current_slice_is_space ^ is_now_space:
|
193 |
+
current_slice_len = 1
|
194 |
+
current_slice_is_space = is_now_space
|
195 |
+
else:
|
196 |
+
current_slice_len += 1
|
197 |
+
if current_slice_len > max_consecutive_slice_len:
|
198 |
+
yield s[slice_start:i]
|
199 |
+
slice_start = i
|
200 |
+
current_slice_len = 1
|
201 |
+
yield s[slice_start:]
|
202 |
+
|
203 |
+
def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
|
204 |
+
"""
|
205 |
+
Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
|
206 |
+
"""
|
207 |
+
# TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
|
208 |
+
# For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
|
209 |
+
# & everything else expects `tokens`
|
210 |
+
text = sample["text"]
|
211 |
+
tokens = self.encode(
|
212 |
+
text, bos=True, eos=True, allowed_special=set(["<|image|>"])
|
213 |
+
)
|
214 |
+
input_ids = torch.LongTensor(tokens[:-1])
|
215 |
+
labels = torch.LongTensor(tokens[1:])
|
216 |
+
labels = torch.where(
|
217 |
+
torch.isin(
|
218 |
+
labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
|
219 |
+
),
|
220 |
+
IGNORE_INDEX,
|
221 |
+
labels,
|
222 |
+
)
|
223 |
+
|
224 |
+
assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete
|
225 |
+
|
226 |
+
sample.update({"tokens": input_ids, "labels": labels})
|
227 |
+
|
228 |
+
return sample
|
229 |
+
|
230 |
+
|
231 |
+
def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
|
232 |
+
return TikTokenizer(job_config.model.tokenizer_path)
|
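One detail worth calling out from the tokenizer above: long inputs are chunked before reaching tiktoken, and `_split_whitespaces_or_nonwhitespaces` guarantees the chunks concatenate back to the original string. A minimal check (the import path is assumed from the file location; the splitter is a `@staticmethod`, so no tokenizer model file is needed):

```python
# Path assumed from the file location above.
from torchtitan.experiments.multimodal.tokenizer.tiktoken import TikTokenizer

s = ("a" * 10 + " " * 10) * 3
chunks = list(
    TikTokenizer._split_whitespaces_or_nonwhitespaces(s, max_consecutive_slice_len=4)
)

# The split is lossless: concatenating the chunks reproduces the input.
assert "".join(chunks) == s
print([len(c) for c in chunks])
```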
torchtitan/experiments/multimodal/utils.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import List, Optional, Set, Tuple, Union
|
13 |
+
from urllib import request
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchvision
|
17 |
+
from torchvision.transforms.v2 import functional as F
|
18 |
+
|
19 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
|
20 |
+
def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
image (torch.Tensor): Input image to crop into tiles.
|
26 |
+
tile_size (int): Size of each tile.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
|
30 |
+
|
31 |
+
Examples:
|
32 |
+
>>> image = torch.rand(3, 200, 300)
|
33 |
+
>>> tiles = tile_crop(image, tile_size=50)
|
34 |
+
>>> tiles.shape # 4x6 = 24 tiles
|
35 |
+
torch.Size([24, 3, 50, 50])
|
36 |
+
|
37 |
+
>>> image = torch.rand(3, 400, 600)
|
38 |
+
>>> tiles = tile_crop(image, tile_size=200)
|
39 |
+
>>> tiles.shape # 2x3 = 6 tiles
|
40 |
+
torch.Size([6, 3, 200, 200])
|
41 |
+
"""
|
42 |
+
|
43 |
+
channel_size, height, width = image.shape
|
44 |
+
|
45 |
+
# assert sizes are divisible
|
46 |
+
assert (
|
47 |
+
height % tile_size == 0 and width % tile_size == 0
|
48 |
+
), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
|
49 |
+
|
50 |
+
# Reshape to split height and width into tile_size blocks
|
51 |
+
tiles_height = height // tile_size
|
52 |
+
tiles_width = width // tile_size
|
53 |
+
|
54 |
+
reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
|
55 |
+
|
56 |
+
# Transpose to bring tiles together
|
57 |
+
# We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
|
58 |
+
transposed = reshaped.permute(1, 3, 0, 2, 4)
|
59 |
+
|
60 |
+
# Flatten the tiles
|
61 |
+
tiles = transposed.contiguous().view(
|
62 |
+
tiles_height * tiles_width, channel_size, tile_size, tile_size
|
63 |
+
)
|
64 |
+
|
65 |
+
return tiles
|
66 |
+
|
67 |
+
|
68 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
69 |
+
def resize_with_pad(
|
70 |
+
image: torch.Tensor,
|
71 |
+
target_size: Tuple[int, int],
|
72 |
+
resample: torchvision.transforms.InterpolationMode,
|
73 |
+
max_size: Optional[int] = None,
|
74 |
+
) -> torch.Tensor:
|
75 |
+
"""
|
76 |
+
Resizes and pads an image to target_size without causing distortion.
|
77 |
+
The user can set max_size to limit upscaling when target_size exceeds image_size.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
81 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
82 |
+
resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
|
83 |
+
Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
|
84 |
+
InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
|
85 |
+
max_size (Optional[int]): The maximum size to upscale the image to.
|
86 |
+
If None, will upscale up to target_size.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
torch.Tensor: The resized and padded image tensor in the format [..., H, W].
|
90 |
+
|
91 |
+
Examples:
|
92 |
+
|
93 |
+
Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
|
94 |
+
and then padded from (448, 1194) to (448, 1344).
|
95 |
+
|
96 |
+
>>> max_size = None
|
97 |
+
>>> image = torch.rand([3, 300, 800])
|
98 |
+
>>> target_size = (448, 1344)
|
99 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
100 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
101 |
+
|
102 |
+
Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).
|
103 |
+
|
104 |
+
>>> max_size = 600
|
105 |
+
>>> image = torch.rand([3, 300, 800])
|
106 |
+
>>> target_size = (448, 1344)
|
107 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
108 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
109 |
+
|
110 |
+
Example 3: The image will be downscaled from (500, 1000) to (224, 448),
|
111 |
+
and padded from (224, 448) to (448, 448).
|
112 |
+
|
113 |
+
>>> max_size = 600
|
114 |
+
>>> image = torch.rand([3, 500, 1000])
|
115 |
+
>>> target_size = (448, 488)
|
116 |
+
>>> resample = torchvision.transforms.InterpolationMode.BILINEAR
|
117 |
+
>>> output = resize_with_pad(image, target_size, resample, max_size)
|
118 |
+
|
119 |
+
"""
|
120 |
+
|
121 |
+
image_height, image_width = image.shape[-2:]
|
122 |
+
image_size = (image_height, image_width)
|
123 |
+
|
124 |
+
# If target_size requires upscaling, we might want to limit the upscaling to max_size
|
125 |
+
if max_size is not None:
|
126 |
+
new_target_height = min(max(image_height, max_size), target_size[0])
|
127 |
+
new_target_width = min(max(image_width, max_size), target_size[1])
|
128 |
+
target_size_resize = (new_target_height, new_target_width)
|
129 |
+
else:
|
130 |
+
target_size_resize = target_size
|
131 |
+
|
132 |
+
# resize to target_size while preserving aspect ratio
|
133 |
+
new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
|
134 |
+
image_size=image_size,
|
135 |
+
target_size=target_size_resize,
|
136 |
+
)
|
137 |
+
|
138 |
+
image = F.resize(
|
139 |
+
inpt=image,
|
140 |
+
size=list(new_size_preserving_aspect_ratio),
|
141 |
+
interpolation=resample,
|
142 |
+
antialias=True,
|
143 |
+
)
|
144 |
+
|
145 |
+
image = _pad_image_top_left(image=image, target_size=target_size)
|
146 |
+
|
147 |
+
return image
|
148 |
+
|
149 |
+
|
150 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
151 |
+
def _pad_image_top_left(
|
152 |
+
image: torch.Tensor,
|
153 |
+
target_size: Tuple[int, int],
|
154 |
+
) -> torch.Tensor:
|
155 |
+
"""
|
156 |
+
Places the image at the top left of the canvas and pads the right and bottom with 0
|
157 |
+
to fit to the target resolution. If target_size < image_size, it will crop the image.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
image (torch.Tensor): The input image tensor in the format [..., H, W].
|
161 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
torch.Tensor: The padded image tensor in the format [..., H, W].
|
165 |
+
"""
|
166 |
+
|
167 |
+
image_size = image.shape[-2:]
|
168 |
+
|
169 |
+
height, width = image_size
|
170 |
+
target_height, target_width = target_size
|
171 |
+
|
172 |
+
pad_x = target_width - width
|
173 |
+
pad_y = target_height - height
|
174 |
+
|
175 |
+
padding = [0, 0, pad_x, pad_y]
|
176 |
+
return F.pad(inpt=image, padding=padding)
|
177 |
+
|
178 |
+
|
179 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
|
180 |
+
def _get_max_res_without_distortion(
|
181 |
+
image_size: Tuple[int, int],
|
182 |
+
target_size: Tuple[int, int],
|
183 |
+
) -> Tuple[int, int]:
|
184 |
+
"""
|
185 |
+
Determines the maximum resolution to which an image can be resized to without distorting its
|
186 |
+
aspect ratio, based on the target resolution.
|
187 |
+
|
188 |
+
For example, if image_size = (200,400) and target_size = (600,800),
|
189 |
+
scale_h = 600/200 = 3
|
190 |
+
scale_w = 800/400 = 2
|
191 |
+
So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
|
192 |
+
|
193 |
+
Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
|
194 |
+
|
195 |
+
Args:
|
196 |
+
image_size (Tuple[int, int]): The original resolution of the image.
|
197 |
+
target_size (Tuple[int, int]): The desired resolution to fit the image into.
|
198 |
+
Returns:
|
199 |
+
Tuple[int, int]: The optimal dimensions to which the image should be resized.
|
200 |
+
Examples:
|
201 |
+
>>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
|
202 |
+
(133, 200)
|
203 |
+
>>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
|
204 |
+
(450, 337)
|
205 |
+
"""
|
206 |
+
|
207 |
+
original_height, original_width = image_size
|
208 |
+
target_height, target_width = target_size
|
209 |
+
|
210 |
+
scale_w = target_width / original_width
|
211 |
+
scale_h = target_height / original_height
|
212 |
+
|
213 |
+
if scale_w < scale_h:
|
214 |
+
new_width = target_width
|
215 |
+
new_height = min(math.floor(original_height * scale_w), target_height)
|
216 |
+
else:
|
217 |
+
new_height = target_height
|
218 |
+
new_width = min(math.floor(original_width * scale_h), target_width)
|
219 |
+
|
220 |
+
return new_height, new_width
|
221 |
+
|
222 |
+
|
223 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
224 |
+
def _get_factors(n: int) -> Set[int]:
|
225 |
+
"""
|
226 |
+
Calculate all factors of a given number, i.e. divisors that leave no remainder.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
n (int): The number to find factors for.
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
set: A set containing all factors of the number.
|
233 |
+
|
234 |
+
Examples:
|
235 |
+
>>> _get_factors(n=12)
|
236 |
+
{1, 2, 3, 4, 6, 12}
|
237 |
+
"""
|
238 |
+
factors_set = set()
|
239 |
+
|
240 |
+
for i in range(1, int(n**0.5) + 1):
|
241 |
+
if n % i == 0:
|
242 |
+
factors_set.add(i)
|
243 |
+
factors_set.add(n // i)
|
244 |
+
return factors_set
|
245 |
+
|
246 |
+
|
247 |
+
# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
|
248 |
+
def get_canvas_best_fit(
|
249 |
+
image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
|
250 |
+
) -> Tuple[int, int]:
|
251 |
+
"""
|
252 |
+
Determines the best canvas possible from a list of possible resolutions to
|
253 |
+
resize an image to, without distortion.
|
254 |
+
|
255 |
+
For each possible resolution, calculates the scaling factors for
|
256 |
+
width and height, and selects the smallest one, which is the limiting side.
|
257 |
+
E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x,
|
258 |
+
then the maximum upscaling without distortion is min(2, 1.5) = 1.5.
|
259 |
+
|
260 |
+
If there are multiple canvases that satisfy the conditions,
|
261 |
+
we pick the one with the lowest area to minimize padding.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
image (torch.Tensor): The image we want to fit into a canvas.
|
265 |
+
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
266 |
+
row represents a possible canvas.
|
267 |
+
resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
|
268 |
+
If False, pick the canvas that minimizes downscaling, including no downscaling at all.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
Tuple[int, int]: The best resolution to fit the image into.
|
272 |
+
|
273 |
+
Examples:
|
274 |
+
>>> image = torch.rand(3, 200, 300)
|
275 |
+
>>> possible_resolutions = torch.tensor([
|
276 |
+
... [224, 672],
|
277 |
+
... [672, 224],
|
278 |
+
... [224, 448],
|
279 |
+
... [448, 224],
|
280 |
+
... [224, 224]
|
281 |
+
... ])
|
282 |
+
>>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
|
283 |
+
(224, 448)
|
284 |
+
|
285 |
+
In the example above, we calculate the scaling factors for each possible resolution
|
286 |
+
|
287 |
+
>>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
288 |
+
>>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
289 |
+
>>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
290 |
+
|
291 |
+
Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest
|
292 |
+
|
293 |
+
>>> upscaling_options = torch.tensor([1.1200, 1.1200])
|
294 |
+
>>> selected_scale = torch.tensor(1.1200)
|
295 |
+
|
296 |
+
There are two possible options, so we pick the one with the smallest area
|
297 |
+
|
298 |
+
>>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively
|
299 |
+
>>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area
|
300 |
+
"""
|
301 |
+
|
302 |
+
original_height, original_width = image.shape[-2:]
|
303 |
+
|
304 |
+
# possible resolutions heights/widths
|
305 |
+
target_heights, target_widths = (
|
306 |
+
possible_resolutions[:, 0],
|
307 |
+
possible_resolutions[:, 1],
|
308 |
+
)
|
309 |
+
|
310 |
+
# scaling factors to resize the image without distortion
|
311 |
+
scale_w = target_widths / original_width
|
312 |
+
scale_h = target_heights / original_height
|
313 |
+
|
314 |
+
# get limiting side scaling -> no distortion
|
315 |
+
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
316 |
+
|
317 |
+
# filter only scales that allow upscaling
|
318 |
+
upscaling_options = scales[scales >= 1]
|
319 |
+
if len(upscaling_options) > 0:
|
320 |
+
if resize_to_max_canvas:
|
321 |
+
selected_scale = torch.max(upscaling_options)
|
322 |
+
else:
|
323 |
+
selected_scale = torch.min(upscaling_options)
|
324 |
+
else:
|
325 |
+
# no upscaling possible,
|
326 |
+
# get the minimum downscaling (max scale for scales<1)
|
327 |
+
downscaling_options = scales[scales < 1]
|
328 |
+
selected_scale = torch.max(downscaling_options)
|
329 |
+
|
330 |
+
# get all resolutions that support this scaling factor,
|
331 |
+
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
332 |
+
chosen_canvas = possible_resolutions[scales == selected_scale]
|
333 |
+
|
334 |
+
# if there are multiple resolutions,
|
335 |
+
# get the one with minimum area to reduce padding
|
336 |
+
if len(chosen_canvas) > 1:
|
337 |
+
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
338 |
+
        optimal_idx = torch.argmin(areas)
        optimal_canvas = chosen_canvas[optimal_idx]
    else:
        optimal_canvas = chosen_canvas[0]

    return tuple(optimal_canvas.tolist())


# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
def find_supported_resolutions(
    max_num_tiles: int, tile_size: int
) -> List[Tuple[int, int]]:
    """
    Computes all combinations of resolutions, multiple of tile_size,
    that contain up to max_num_tiles. Useful for when dividing an image into tiles.

    For example, if we want at most 2 tiles per image, then we can support the
    following resolutions: (1x1, 1x2, 2x1) * tile_size

    Args:
        max_num_tiles (int): Maximum number of tiles.
        tile_size (int): Size of the side of the tile.

    Returns:
        List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).

    Examples:

        >>> max_num_tiles = 4
        >>> tile_size = 224
        >>> find_supported_resolutions(max_num_tiles, tile_size)
        [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
    """

    # create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
    # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
    asp_dict = defaultdict(list)
    for _tile_size in range(max_num_tiles, 0, -1):
        factors = sorted(_get_factors(_tile_size))
        asp_ratios = [(factor, _tile_size // factor) for factor in factors]
        for height, width in asp_ratios:
            ratio_float = height / width
            asp_dict[ratio_float].append((height, width))

    # get the resolutions multiplied by the tile_size
    possible_resolutions = []
    for ar, resolution in asp_dict.items():
        for height, width in resolution:
            possible_resolutions.append((height * tile_size, width * tile_size))

    return possible_resolutions


# NOTE Copied from torchtune.data._utils.py
def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
    """
    Convenience method to load an image in torch.Tensor format from a local file path or remote source.

    Args:
        image_loc (Union[Path, str]): Local file path or remote source pointing to the image
            which will be loaded in PIL format.

    Note:
        If loading an image from a remote source, the function expects the URL provided in ``image_loc``
        to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".

    Raises:
        ValueError: If the image cannot be loaded from remote source, **or**
            if the image cannot be opened as a :class:`~torch.Tensor`.

    Examples:
        >>> # Load from remote source
        >>> image = load_image("https://www.wikipedia.org/en/bird.jpg")

        >>> # Load from local file path
        >>> image = load_image(Path("/home/user/bird.jpg"))

    Returns:
        torch.Tensor: The loaded image.
    """

    # If pointing to remote source, try to load to local
    if isinstance(image_loc, str) and image_loc.startswith("http"):
        try:
            image_loc = request.urlopen(image_loc).read()
            image = torchvision.io.decode_image(
                torch.frombuffer(image_loc, dtype=torch.uint8),
                mode="RGB",
            )
        except Exception as e:
            raise ValueError("Failed to load remote image as torch.Tensor") from e

    # Open the local image as a Tensor image
    else:
        try:
            image = torchvision.io.decode_image(image_loc, mode="RGB")
        except Exception as e:
            raise ValueError("Failed to load local image as torch.Tensor") from e

    return image
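A short usage sketch for the two helpers above, reusing the values from their docstrings (the image path is a placeholder, and the sketch assumes both functions are imported from this module):

```python
# Illustrative usage of the helpers above; not part of the diff itself.
from pathlib import Path

# All (height, width) canvases made of up to 4 tiles of 224x224 pixels.
possible_resolutions = find_supported_resolutions(max_num_tiles=4, tile_size=224)
print(possible_resolutions)  # e.g. [(224, 896), (448, 448), (224, 224), ...]

# Decode an image into an RGB uint8 tensor, from disk or over HTTP.
image = load_image(Path("/home/user/bird.jpg"))
print(image.shape)
```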
torchtitan/experiments/simple_fsdp/README.md
ADDED
@@ -0,0 +1,40 @@
## SimpleFSDP

This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework: it has a simple implementation that is easy to maintain and compose, allows full computation-communication graph tracing, and brings performance enhancements via compiler backend optimizations.

### Enable SimpleFSDP Training

```bash
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32
```

Note: Mixed precision training support is ongoing. We set `training.mixed_precision_param` to `float32` for now and will remove it once the integration is complete.

### Composability Support

Some of these features require updates to PyTorch; we are working on composability support for the following:

| Feature | Support |
| :--------: | :--------: |
| Meta Initialization | ✅ |
| Activation Checkpointing | ✅ |
| Mixed Precision Training | 🚧 |
| Tensor Parallelism | 🚧 |
| Context Parallelism | ✅ |
| Pipeline Parallelism | ✅ |
| Distributed Checkpointing | 🚧 |
| Float8 Training | ❌ |

### Citation

If you find SimpleFSDP useful, please consider citing the following paper:

```latex
@article{zhang2024simplefsdp,
  title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile},
  author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
  journal={arXiv preprint arXiv:2411.00284},
  year={2024}
}
```
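Beyond the `run_train.sh` entry point, the numerics test added further down in this diff exercises the underlying `data_parallel` wrapper directly. The following is a minimal sketch of that usage; the toy module, mesh shape, and one-step training loop are illustrative assumptions rather than part of the README:

```python
# Illustrative sketch: wrap a small module with SimpleFSDP's data_parallel and
# take one optimizer step. Assumes a distributed launch (e.g. torchrun with one
# rank per GPU on a single host).
import torch
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

# Mesh dim name and mode values mirror the numerics test below.
mesh = init_device_mesh(
    "cuda", (torch.cuda.device_count(),), mesh_dim_names=("dp_shard_cp",)
)
model = data_parallel(torch.nn.Linear(8, 8), device_mesh=mesh, mode="fully_shard")
model = torch.compile(model)  # mirrors the README's --training.compile flag

optim = torch.optim.Adam(model.parameters(), lr=1e-4)
inputs, labels = torch.randn(8, 8).cuda(), torch.randn(8, 8).cuda()
loss = torch.nn.functional.mse_loss(model(inputs), labels)
loss.backward()
optim.step()
```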
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.1 kB)
torchtitan/experiments/simple_fsdp/tests/test_numerics.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy

import torch
from torch.distributed._composable.fsdp import fully_shard

from torch.testing._internal.common_fsdp import FSDPTest

from torchtitan.components.loss import cross_entropy_loss
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel


class TestSimpleFSDP(FSDPTest):
    # Numerics check: SimpleFSDP's data_parallel should match FSDP2 (fully_shard)
    # loss-for-loss on a small linear model across all three sharding modes.
    def init_test(self):
        self.optimizer = torch.optim.Adam
        self.loss_fn = cross_entropy_loss
        data_parallel_shard_degree = -1
        if self.mode == "replicate":
            self.dp_mesh_dim_names = ("dp_replicate",)
            data_parallel_replicate_degree = self.world_size
        elif self.mode == "fully_shard":
            self.dp_mesh_dim_names = ("dp_shard_cp",)
            data_parallel_replicate_degree = 1
        elif self.mode == "hybrid_shard":
            self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
            data_parallel_replicate_degree = self.world_size // 2
        else:
            raise ValueError(f"Unsupported mode {self.mode}")

        self.parallel_dims = ParallelDims(
            dp_shard=data_parallel_shard_degree,
            dp_replicate=data_parallel_replicate_degree,
            cp=1,
            tp=1,
            pp=1,
            world_size=self.world_size,
            enable_loss_parallel=True,
        )
        self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")

    def get_input(self):
        inputs = torch.randn(8, 8).cuda()
        labels = torch.randn(8, 8).cuda()
        model = torch.nn.Linear(8, 8)
        return model, inputs, labels

    def run_fsdp2(self, model, inputs, labels, epoch=20):
        fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
        optim = self.optimizer(model.parameters(), lr=1e-4)
        losses = []
        for _ in range(epoch):
            optim.zero_grad()
            out = model(inputs)
            loss = self.loss_fn(out, labels)
            loss.backward()
            optim.step()
            losses.append(loss)
        return losses

    def run_simple_fsdp(self, model, inputs, labels, epoch=20):
        model = data_parallel(
            model,
            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
            mode=self.mode,
        )
        optim = self.optimizer(model.parameters(), lr=1e-4)
        losses = []
        for _ in range(epoch):
            optim.zero_grad()
            out = model(inputs)
            loss = self.loss_fn(out, labels)
            loss.backward()
            optim.step()
            losses.append(loss)
        return losses

    def test_replicate_convergence(self):
        # unit test for replicate mode
        self.mode = "replicate"
        self.init_test()
        model, inputs, labels = self.get_input()

        fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
        simple_fsdp_replicate_losses = self.run_simple_fsdp(
            copy.deepcopy(model), inputs, labels
        )

        for fsdp2_loss, simple_fsdp_replicate_loss in zip(
            fsdp2_losses, simple_fsdp_replicate_losses
        ):
            assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)

    def test_fullyshard_convergence(self):
        # unit test for fully_shard mode
        self.mode = "fully_shard"
        self.init_test()
        model, inputs, labels = self.get_input()

        fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
        simple_fsdp_fullyshard_losses = self.run_simple_fsdp(
            copy.deepcopy(model), inputs, labels
        )

        for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
            fsdp2_losses, simple_fsdp_fullyshard_losses
        ):
            assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)

    def test_hybridshard_convergence(self):
        # unit test for hybrid_shard mode
        self.mode = "hybrid_shard"
        self.init_test()
        model, inputs, labels = self.get_input()

        fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
        simple_fsdp_hybridshard_losses = self.run_simple_fsdp(
            copy.deepcopy(model), inputs, labels
        )

        for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
            fsdp2_losses, simple_fsdp_hybridshard_losses
        ):
            assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
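Since `TestSimpleFSDP` subclasses `FSDPTest`, each test spawns multiple worker processes (one per participating GPU) and therefore needs a multi-GPU host. A plausible invocation, not documented in this diff, is a standard test runner such as `pytest torchtitan/experiments/simple_fsdp/tests/test_numerics.py`.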
torchtitan/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (186 Bytes)

torchtitan/models/llama3/__pycache__/model.cpython-312.pyc
ADDED
Binary file (25.9 kB)

torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc
ADDED
Binary file (15.1 kB)
torchtitan/models/llama3/parallelize_llama.py
ADDED
@@ -0,0 +1,398 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    for layer_id, transformer_block in model.layers.items():
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = int(layer_id) < len(model.layers) - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
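As a rough illustration of how `apply_fsdp` above can be driven outside of `parallelize_llama`, here is a minimal sketch. The `ToyModel`, mesh construction, and dtypes are assumptions for illustration only; a real run would be launched with torchrun so every rank participates:

```python
# Illustrative sketch only: shard a hypothetical toy model with apply_fsdp.
# ToyModel just mimics the `.layers` container apply_fsdp iterates over
# (keys castable to int, as in the Llama model).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.models.llama3.parallelize_llama import apply_fsdp


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(dim, dim) for i in range(n_layers)}
        )

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


# Assumes a single-host torchrun launch, one rank per GPU.
dp_mesh = init_device_mesh(
    "cuda", (torch.cuda.device_count(),), mesh_dim_names=("dp_shard_cp",)
)
model = ToyModel()
apply_fsdp(
    model,
    dp_mesh,
    param_dtype=torch.bfloat16,  # parameters are all-gathered in bf16
    reduce_dtype=torch.float32,  # gradients are reduced in fp32
    pp_enabled=False,            # "default" policy reshards all but the last block
)
```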
torchtitan/models/llama3/train_configs/llama3_405b.toml
ADDED
@@ -0,0 +1,63 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 128 H100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 405B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
converters = "float8"

[optimizer]
name = "AdamW"
lr = 8e-5
eps = 1e-8

[lr_scheduler]
warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 2
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 3000
compile = true
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8  # 8-way TP
enable_async_tensor_parallel = true
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'  # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = "output"
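Presumably this preset is launched through the same `run_train.sh` entry point shown in the SimpleFSDP README above, e.g. `CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_405b.toml" ./run_train.sh`, plus whatever multi-node launcher flags a 128-GPU setup requires; the exact launch command is not part of this diff.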
torchtitan/tools/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (5.28 kB)