zaydzuhri committed
Commit 75edb55 · verified · 1 Parent(s): 7641cdc

Add files using upload-large-folder tool

Files changed (50)
  1. logs/none_99omtdbz/attempt_0/0/stderr.log +0 -0
  2. logs/none_99omtdbz/attempt_0/2/stderr.log +0 -0
  3. logs/none_99omtdbz/attempt_0/5/stdout.log +0 -0
  4. logs/none_99omtdbz/attempt_0/6/stdout.log +0 -0
  5. profile_trace/iteration_17920/rank1_trace.json +0 -0
  6. profile_trace/iteration_17920/rank5_trace.json +0 -0
  7. profile_trace/iteration_21504/rank2_trace.json +0 -0
  8. profile_trace/iteration_24576/rank1_trace.json +0 -0
  9. profile_trace/iteration_24576/rank7_trace.json +0 -0
  10. profile_trace/iteration_33792/rank1_trace.json +0 -0
  11. profile_trace/iteration_39936/rank2_trace.json +0 -0
  12. profile_trace/iteration_39936/rank4_trace.json +0 -0
  13. profile_trace/iteration_512/rank0_trace.json +0 -0
  14. pyproject.toml +43 -0
  15. torchtitan/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  16. torchtitan/components/__pycache__/dataloader.cpython-312.pyc +0 -0
  17. torchtitan/components/__pycache__/ft.cpython-312.pyc +0 -0
  18. torchtitan/components/float8.py +150 -0
  19. torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc +0 -0
  20. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py +159 -0
  21. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  22. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
  23. torchtitan/experiments/flux/flux_argparser.py +42 -0
  24. torchtitan/experiments/flux/loss.py +27 -0
  25. torchtitan/experiments/flux/requirements.txt +2 -0
  26. torchtitan/experiments/flux/tests/test_flux_dataloader.py +103 -0
  27. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  28. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  29. torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py +630 -0
  30. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  31. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py +126 -0
  32. torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
  33. torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
  34. torchtitan/experiments/llama4/model/model.py +466 -0
  35. torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py +536 -0
  36. torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
  37. torchtitan/experiments/multimodal/mm_dataset.py +268 -0
  38. torchtitan/experiments/multimodal/requirements.txt +1 -0
  39. torchtitan/experiments/multimodal/tests/test_utils.py +58 -0
  40. torchtitan/experiments/multimodal/tokenizer/tiktoken.py +232 -0
  41. torchtitan/experiments/multimodal/utils.py +437 -0
  42. torchtitan/experiments/simple_fsdp/README.md +40 -0
  43. torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc +0 -0
  44. torchtitan/experiments/simple_fsdp/tests/test_numerics.py +128 -0
  45. torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
  46. torchtitan/models/llama3/__pycache__/model.cpython-312.pyc +0 -0
  47. torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc +0 -0
  48. torchtitan/models/llama3/parallelize_llama.py +398 -0
  49. torchtitan/models/llama3/train_configs/llama3_405b.toml +63 -0
  50. torchtitan/tools/__pycache__/utils.cpython-312.pyc +0 -0
logs/none_99omtdbz/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_99omtdbz/attempt_0/2/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_99omtdbz/attempt_0/5/stdout.log ADDED
File without changes
logs/none_99omtdbz/attempt_0/6/stdout.log ADDED
File without changes
profile_trace/iteration_17920/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_17920/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_21504/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_24576/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_24576/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33792/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_39936/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_39936/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_512/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+ [project]
+ name = "flame"
+ dynamic = ["version"]
+ description = "A minimal training framework for scaling FLA models"
+ readme = "README.md"
+ authors = [
+     { name = "Songlin Yang", email = "yangsl66@mit.edu" },
+     { name = "Yu Zhang", email = "yzhang.cs@outlook.com" },
+ ]
+ license = { file = "LICENSE" }
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ requires-python = ">=3.10"
+ dependencies = [
+     'torch==2.6',
+     'torchdata',
+     'transformers==4.51.3',
+     'triton>=3.0',
+     'datasets>=3.3.0',
+     'einops',
+     'ninja',
+     'wandb',
+     'tiktoken',
+     'tensorboard',
+     'python-dotenv'
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest"]
+
+ [project.urls]
+ Homepage = "https://github.com/fla-org/flame"
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel", "ninja", "torch"]
+
+ [tool.isort]
+ line_length = 127
+ multi_line_output = 3
torchtitan/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (33.1 kB)
 
torchtitan/components/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.78 kB)
 
torchtitan/components/__pycache__/ft.cpython-312.pyc ADDED
Binary file (6.75 kB)
 
torchtitan/components/float8.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # [Note] Getting the 'torchao' package:
+ # This script requires the 'torchao' package to function correctly.
+ # Please ensure you have this package installed from the appropriate repository.
+ # You can obtain it from https://github.com/pytorch/ao by following the
+ # installation instructions.
+
+ # Note: Performance
+ # Float8 experimental is intended to be run under `torch.compile` for competitive performance
+
+ import torch
+ import torch.nn as nn
+
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.distributed import ParallelDims
+ from torchtitan.protocols.model_converter import (
+     ModelConverter,
+     register_model_converter,
+ )
+ from torchtitan.tools.logging import logger
+
+
+ def _is_sm89_or_later():
+     # Float8 is only supported on SM89 or later (H100+ GPUs)
+     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+ class Float8Converter(ModelConverter):
+     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+         self.enabled = False
+
+         float8_config = job_config.float8
+         if not _is_sm89_or_later():
+             logger.warning(
+                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+             )
+             return
+         try:
+             from torchao.float8 import Float8LinearConfig
+         except ImportError as e:
+             raise ImportError(
+                 "torchao is not installed. Please install it to use float8 linear layers."
+             ) from e
+
+         if float8_config.recipe_name is not None and not hasattr(
+             Float8LinearConfig, "from_recipe_name"
+         ):
+             logger.warning(
+                 "Failed to swap to Float8Linear with recipe lookup because the torchao version "
+                 "is too old, please install torchao v0.9.0 or later and try again",
+             )
+             return
+
+         self.enabled = True
+         self.filter_fqns = float8_config.filter_fqns
+
+         if float8_config.recipe_name is not None:
+             assert (
+                 not float8_config.enable_fsdp_float8_all_gather
+             ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
+             assert (
+                 not float8_config.force_recompute_fp8_weight_in_bwd
+             ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+             self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
+             self.precompute_scale = False
+             logger.info(
+                 f"Float8 training active with recipe {float8_config.recipe_name}"
+             )
+
+         else:
+             # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+             enable_fsdp_float8_all_gather = (
+                 parallel_dims.dp_shard_enabled
+                 and float8_config.enable_fsdp_float8_all_gather
+             )
+             self.config = Float8LinearConfig(
+                 enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                 force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
+             )
+             # for precompute_float8_dynamic_scale_for_fsdp
+             self.precompute_scale = (
+                 enable_fsdp_float8_all_gather
+                 and float8_config.precompute_float8_dynamic_scale_for_fsdp
+             )
+             logger.info("Float8 tensorwise scaled training active")
+
+     def convert(self, model: nn.Module):
+         return self.convert_to_float8_training(model)
+
+     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
+         return self.precompute_float8_dynamic_scale_for_fsdp(model)
+
+     def convert_to_float8_training(self, model: nn.Module):
+         """
+         This function converts the linear layers of `model` to `Float8Linear`.
+         Note that today, only dynamic tensor scaling (the default) is supported.
+         This will mutate the model inplace.
+         """
+         if not self.enabled:
+             return
+
+         from torchao.float8 import convert_to_float8_training
+
+         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+         convert_to_float8_training(
+             model,
+             config=self.config,
+             module_filter_fn=self._module_filter_fn,
+         )
+         logger.info(
+             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+             f"{self.config.enable_fsdp_float8_all_gather}"
+         )
+
+     def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
+         if not isinstance(mod, nn.Linear):
+             return False
+
+         # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
+         dims_multiples_of_16 = (
+             mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
+         )
+
+         # If the fqn matches any filtered fqn, then we should not convert this module.
+         is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)
+
+         return dims_multiples_of_16 and not is_filtered_fqn
+
+     def precompute_float8_dynamic_scale_for_fsdp(
+         self, model: nn.Module | list[nn.Module]
+     ):
+         if not self.enabled:
+             return
+
+         if not self.precompute_scale:
+             return
+
+         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
+
+         models = [model] if isinstance(model, nn.Module) else model
+         for m in models:
+             precompute_float8_dynamic_scale_for_fsdp(m)
+
+
+ register_model_converter(Float8Converter, "float8")
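The `_module_filter_fn` above only swaps `nn.Linear` layers whose weight dims are both multiples of 16 and whose fully qualified name does not match a filtered substring. A minimal standalone sketch of that selection rule (the helper and the `filter_fqns` value are hypothetical, for illustration only):

import torch.nn as nn

def would_convert(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    # Mirror of Float8Converter._module_filter_fn: convert only nn.Linear layers
    # with both weight dims divisible by 16 (float8 tensorcore requirement) and
    # an FQN that does not contain any filtered substring.
    if not isinstance(mod, nn.Linear):
        return False
    dims_ok = mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
    filtered = any(f in fqn for f in filter_fqns)
    return dims_ok and not filtered

print(would_convert(nn.Linear(4096, 4096), "layers.0.attention.wq", ["output"]))  # True
print(would_convert(nn.Linear(4096, 100), "layers.0.ffn.w3", ["output"]))         # False: 100 % 16 != 0
print(would_convert(nn.Linear(4096, 4096), "output", ["output"]))                 # False: FQN is filtered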
torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc ADDED
Binary file (6.1 kB)
 
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py ADDED
@@ -0,0 +1,159 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import triton
+ import triton.language as tl
+
+ from .triton_utils import get_flat_bid, get_flat_tid
+
+
+ @triton.jit
+ def send_signal(addrs, sem: tl.constexpr):
+     if sem == "relaxed":
+         tl.inline_asm_elementwise(
+             """
+             {
+                 .reg .u32  %tmp32_<1>;
+                 .reg .pred %p<1>;
+
+             send_signal:
+                 atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                 setp.eq.u32 %p0, %tmp32_0, 0;
+                 @!%p0 bra send_signal;
+             }
+             """,
+             "=r, l",
+             [addrs],
+             dtype=tl.int32,
+             is_pure=False,
+             pack=1,
+         )
+     elif sem == "acq_rel":
+         tl.inline_asm_elementwise(
+             """
+             {
+                 .reg .u32  %tmp32_<1>;
+                 .reg .pred %p<1>;
+
+             send_signal:
+                 atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                 setp.eq.u32 %p0, %tmp32_0, 0;
+                 @!%p0 bra send_signal;
+             }
+             """,
+             "=r, l",
+             [addrs],
+             dtype=tl.int32,
+             is_pure=False,
+             pack=1,
+         )
+     else:
+         raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+ @triton.jit
+ def wait_signal(addrs, sem: tl.constexpr):
+     if sem == "relaxed":
+         tl.inline_asm_elementwise(
+             """
+             {
+                 .reg .u32  %tmp32_<1>;
+                 .reg .pred %p<1>;
+
+             wait_signal:
+                 atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
+                 setp.eq.u32 %p0, %tmp32_0, 1;
+                 @!%p0 bra wait_signal;
+             }
+             """,
+             "=r, l",
+             [addrs],
+             dtype=tl.int32,
+             is_pure=False,
+             pack=1,
+         )
+     elif sem == "acq_rel":
+         tl.inline_asm_elementwise(
+             """
+             {
+                 .reg .u32  %tmp32_<1>;
+                 .reg .pred %p<1>;
+
+             wait_signal:
+                 atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
+                 setp.eq.u32 %p0, %tmp32_0, 1;
+                 @!%p0 bra wait_signal;
+             }
+             """,
+             "=r, l",
+             [addrs],
+             dtype=tl.int32,
+             is_pure=False,
+             pack=1,
+         )
+     else:
+         raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+ @triton.jit
+ def blockwise_barrier(
+     signal_pad_ptrs,
+     block_id,
+     rank: tl.constexpr,
+     world_size: tl.constexpr,
+     sem: tl.constexpr,
+ ):
+     """
+     Synchronizes blocks with matching block_id across participating devices.
+
+     Note: the function itself is not a system level barrier/fence. It is a
+     building block for expressing different synchronization patterns.
+
+     Pattern 0: Ensures that all writes to symm_mem buffers from previous
+     kernels across all devices are visible to the current kernel:
+
+         blockwise_barrier(..., sem="relaxed")
+         sync_threads()
+
+     Pattern 1: Ensures that all writes to symm_mem buffers from the current
+     block are visible to all remote blocks with matching blockIdx:
+
+         sync_threads()
+         blockwise_barrier(..., sem="acq_rel")
+         sync_threads()
+
+     Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
+     for writing by subsequent kernels across all devices.
+
+         sync_threads()
+         blockwise_barrier(..., sem="relaxed")
+
+     CUDA graph friendliness:
+
+         This barrier operates through atomic operations on a zero-filled signal
+         pad, which resets to a zero-filled state after each successful
+         synchronization. This design eliminates the need for incrementing a
+         flag from host.
+     """
+     if block_id is None:
+         block_id = get_flat_bid()
+     flat_tid = get_flat_tid()
+
+     remote_ranks = tl.arange(0, world_size)
+     signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
+     remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
+         tl.pointer_type(tl.uint32)
+     )
+     send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
+
+     local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
+         tl.pointer_type(tl.uint32)
+     )
+     wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
+
+     if flat_tid < world_size:
+         send_signal(send_addrs, sem)
+     wait_signal(wait_addrs, sem)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.distributed as dist
+ import torch.distributed._symmetric_memory as symm_mem
+ import triton
+ import triton.language as tl
+
+ from .triton_barrier import blockwise_barrier
+ from .triton_utils import sync_threads
+
+
+ @triton.jit
+ def _exchange_row_offsets(
+     split_sizes_ptrs,
+     rank: tl.constexpr,
+     world_size: tl.constexpr,
+     BLOCKS_PER_REMOTE_RANK: tl.constexpr,
+ ):
+     remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
+
+     # split_sizes_ptr for all ranks
+     # All these vector stacks into split_sizes_matrix
+     split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
+
+     # split_sizes_matrix[remote_rank, :]
+     input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
+         tl.pointer_type(tl.int64)
+     )
+
+     offsets_ = tl.arange(0, world_size)
+     input_split_sizes = tl.load(
+         input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
+     )
+
+     num_rows = tl.load(input_split_sizes_ptr + rank)
+     input_row_offset = tl.sum(input_split_sizes) - num_rows
+
+     # split_sizes_matrix[:, rank]
+     output_split_sizes_ptrs = (
+         tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
+     )
+     output_split_sizes = tl.load(
+         output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
+     )
+     output_row_offset = tl.sum(output_split_sizes) - num_rows
+
+     return input_row_offset, output_row_offset, num_rows
+
+
+ @triton.jit
+ def on_device_all_to_all_v_kernel(
+     output_ptr,
+     output_splits_ptr,
+     input_ptrs,
+     input_splits_ptr,
+     signal_pad_ptrs,
+     dim: tl.constexpr,  # Separate dim for easier vectorization
+     rank: tl.constexpr,
+     world_size: tl.constexpr,
+     BLOCKS_PER_REMOTE_RANK: tl.constexpr,
+     UNROLL_FACTOR: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
+     sync_threads()
+
+     remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
+     block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
+
+     input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
+         input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
+     )
+
+     output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
+     if block_offset == 0:
+         # Update output_splits
+         tl.store(output_splits_ptr + remote_rank, num_rows)
+
+     input_ptr = (
+         tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
+             tl.pointer_type(tl.bfloat16)
+         )
+         + input_row_offset * dim
+     )
+     output_ptr = output_ptr + output_row_offset * dim
+
+     outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
+     outer_loop_iters_per_rank = tl.cdiv(
+         tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
+     )
+     numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
+     offset = numel_per_rank * block_offset
+     end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
+
+     unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
+     for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
+         datas = []
+         for j in tl.range(
+             i,
+             i + outer_loop_step,
+             BLOCK_SIZE,
+             loop_unroll_factor=UNROLL_FACTOR,
+         ):
+             offsets = j + tl.arange(0, BLOCK_SIZE)
+             data = tl.load(input_ptr + offsets)
+             tl.store(output_ptr + offsets, data)
+
+     offset += unroll_region_size
+     while offset < end:
+         offsets = offset + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < num_rows * dim
+         data = tl.load(input_ptr + offsets, mask=mask)
+         tl.store(output_ptr + offsets, data, mask=mask)
+         offset += BLOCK_SIZE
+
+     sync_threads()
+     blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
+     return
+
+
+ def _on_device_all_to_all_v(
+     output: torch.Tensor,
+     output_splits: torch.Tensor,
+     input: torch.Tensor,
+     input_splits: torch.Tensor,
+     group: dist.ProcessGroup = dist.group.WORLD,
+     BLOCKS_PER_REMOTE_RANK=8,
+     UNROLL_FACTOR: int = 8,
+     BLOCK_SIZE: int = 16384,
+ ):
+     assert output.dim() == 2, f"{output.shape}"
+     assert input.dim() == 2, f"{input.shape}"
+     assert output.shape[1] == input.shape[1]
+
+     dim = output.shape[1]
+     input_hdl = symm_mem.rendezvous(input, group=group)
+     input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
+
+     num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
+     kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
+         output,
+         output_splits,
+         input_hdl.buffer_ptrs_dev,
+         input_splits_hdl.buffer_ptrs_dev,
+         input_hdl.signal_pad_ptrs_dev,
+         dim=dim,
+         rank=input_hdl.rank,
+         world_size=input_hdl.world_size,
+         BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
+         UNROLL_FACTOR=UNROLL_FACTOR,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=16,
+     )
+     # log_triton_kernel(kernel)
+     return output
+
+
+ class OnDeviceAllToAllV(torch.autograd.Function):
+     # A symmetric memory holding the grad_output during backward
+     grad_output_buf = None
+     # A symmetric memory for exchanges split sizes during both forward and backward
+     splits_buf = None
+     # Maximum output length (need to be set before use of OnDeviceAllToAllV)
+     max_output_len = None
+
+     @staticmethod
+     def forward(
+         ctx,
+         input: torch.Tensor,
+         input_splits: torch.Tensor,
+         group: dist.ProcessGroup = dist.group.WORLD,
+     ):
+         """
+         Args:
+             input: input tensor with data for all ranks concatenated.
+             input_splits: input splits of shape (group.world_size,)
+             group: process group to scope the collective.
+         """
+         # Initialize input splits buffer (one time only)
+         if OnDeviceAllToAllV.splits_buf is None:
+             OnDeviceAllToAllV.splits_buf = symm_mem.empty(
+                 *input_splits.shape,
+                 dtype=input_splits.dtype,
+                 device=input_splits.device,
+             )
+
+         if OnDeviceAllToAllV.max_output_len is None:
+             raise RuntimeError(
+                 "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
+             )
+
+         # Allocate output buffer
+         output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
+         # Allocate output splits tensor
+         output_splits = torch.empty_like(input_splits)
+         # Copy input splits to the buffer
+         OnDeviceAllToAllV.splits_buf.copy_(input_splits)
+
+         # Shuffle input to output
+         _on_device_all_to_all_v(
+             output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
+         )
+
+         # Output splits in forward is the input splits in backward
+         ctx.save_for_backward(output_splits)
+         ctx.group = group
+         ctx.input_shape = input.shape
+         return output, output_splits
+
+     @staticmethod
+     def backward(ctx, grad_output, grad_splits):
+         """
+         Backward is implemented as a shuffle of the output's gradients to the input.
+         Args:
+             `grad_output`: output's gradients passed from the downstream.
+             `grad_splits`: unused.
+         """
+
+         # Initialize grad_output buffer (one time only)
+         if OnDeviceAllToAllV.grad_output_buf is None:
+             assert (
+                 OnDeviceAllToAllV.max_output_len is not None
+             ), "`max_output_len` not set"
+             OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
+                 OnDeviceAllToAllV.max_output_len,
+                 *grad_output.shape[1:],
+                 dtype=grad_output.dtype,
+                 device=grad_output.device,
+             )
+
+         # TODO: is there a way to tell autograd to feed grad_output directly to
+         # our symm_mem buffer?
+         OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
+             grad_output
+         )
+
+         # Size info
+         (grad_output_splits,) = ctx.saved_tensors
+         OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
+         grad_input_splits = torch.empty_like(grad_output_splits)  # unused
+         grad_input = grad_output.new_empty(*ctx.input_shape)
+
+         # Shuffle gradients back to the input
+         _on_device_all_to_all_v(
+             grad_input,
+             grad_input_splits,
+             OnDeviceAllToAllV.grad_output_buf,
+             OnDeviceAllToAllV.splits_buf,
+             group=ctx.group,
+         )
+         return grad_input, None, None
+
+
+ # Alias
+ on_device_all_to_all_v = OnDeviceAllToAllV.apply
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def get_tid():
+     return tl.inline_asm_elementwise(
+         """
+         mov.u32 $0, %tid.x;
+         mov.u32 $1, %tid.y;
+         mov.u32 $2, %tid.z;
+         """,
+         "=r,=r,=r",
+         [],
+         dtype=(tl.uint32, tl.uint32, tl.uint32),
+         is_pure=True,
+         pack=1,
+     )
+
+
+ @triton.jit
+ def get_ntid():
+     return tl.inline_asm_elementwise(
+         """
+         mov.u32 $0, %ntid.x;
+         mov.u32 $1, %ntid.y;
+         mov.u32 $2, %ntid.z;
+         """,
+         "=r,=r,=r",
+         [],
+         dtype=(tl.uint32, tl.uint32, tl.uint32),
+         is_pure=True,
+         pack=1,
+     )
+
+
+ @triton.jit
+ def get_flat_tid():
+     tid_x, tid_y, tid_z = get_tid()
+     ntid_x, ntid_y, _ = get_ntid()
+     return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
+
+
+ @triton.jit
+ def get_flat_bid():
+     return (
+         tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
+         + tl.program_id(1) * tl.num_programs(0)
+         + tl.program_id(0)
+     )
+
+
+ @triton.jit
+ def sync_threads():
+     tl.inline_asm_elementwise(
+         "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+     )
torchtitan/experiments/flux/flux_argparser.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+
+ import torch
+
+
+ def extend_parser(parser: argparse.ArgumentParser) -> None:
+     parser.add_argument(
+         "--training.guidance",
+         type=float,
+         default=3.5,
+         help="guidance value used for guidance distillation",
+     )
+     parser.add_argument(
+         "--encoder.t5_encoder",
+         type=str,
+         default="google/t5-v1_1-small",
+         help="T5 encoder to use, HuggingFace model name.",
+     )
+     parser.add_argument(
+         "--encoder.clip_encoder",
+         type=str,
+         default="openai/clip-vit-large-patch14",
+         help="Clip encoder to use, HuggingFace model name.",
+     )
+     parser.add_argument(
+         "--encoder.encoder_dtype",
+         type=torch.dtype,
+         default=torch.bfloat16,
+         help="Which dtype to load for autoencoder.",
+     )
+     parser.add_argument(
+         "--encoder.max_t5_encoding_len",
+         type=int,
+         default=512,
+         help="Maximum length of the T5 encoding.",
+     )
torchtitan/experiments/flux/loss.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable, TypeAlias
+
+ import torch
+
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.tools.logging import logger
+
+ LossFunction: TypeAlias = Callable[..., torch.Tensor]
+
+
+ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+     """Common MSE loss function for Transformer models training."""
+     return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())
+
+
+ def build_mse_loss(job_config: JobConfig):
+     loss_fn = mse_loss
+     if job_config.training.compile:
+         logger.info("Compiling the loss function with torch.compile")
+         loss_fn = torch.compile(loss_fn)
+     return loss_fn
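A quick standalone check of the `mse_loss` defined above, using plain random tensors (no `JobConfig`); shapes are illustrative only:

import torch
import torch.nn.functional as F

def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same body as the flux loss above: labels are detached, inputs upcast to fp32.
    return F.mse_loss(pred.float(), labels.float().detach())

pred = torch.randn(4, 3, 256, 256, requires_grad=True)
labels = torch.randn(4, 3, 256, 256)

loss = mse_loss(pred, labels)   # scalar tensor
loss.backward()                 # gradients flow only into pred
print(loss.item(), pred.grad.shape)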
torchtitan/experiments/flux/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers
+ einops
torchtitan/experiments/flux/tests/test_flux_dataloader.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import sys
+
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
+ from torchtitan.tools.profiling import (
+     maybe_enable_memory_snapshot,
+     maybe_enable_profiling,
+ )
+
+
+ class TestFluxDataLoader:
+     def test_flux_dataloader(self):
+         dataset_name = "cc12m"
+         batch_size = 32
+         world_size = 4
+         rank = 0
+
+         num_steps = 10
+
+         path = "torchtitan.experiments.flux.flux_argparser"
+         sys.argv.append(f"--experimental.custom_args_module={path}")
+         config = JobConfig()
+         config.maybe_add_custom_args()
+         config.parse_args(
+             [
+                 # Profiling options
+                 # "--profiling.enable_profiling",
+                 # "--profiling.profile_freq",
+                 # "5",
+                 # "--profiling.enable_memory_snapshot",
+                 # "--profiling.save_memory_snapshot_folder",
+                 # "memory_snapshot_flux",
+                 "--training.dataset",
+                 dataset_name,
+                 "--training.batch_size",
+                 str(batch_size),
+                 "--encoder.t5_encoder",
+                 "google/t5-v1_1-small",
+                 "--encoder.clip_encoder",
+                 "openai/clip-vit-large-patch14",
+                 "--encoder.max_t5_encoding_len",
+                 "512",
+             ]
+         )
+
+         with maybe_enable_profiling(
+             config, global_step=0
+         ) as torch_profiler, maybe_enable_memory_snapshot(
+             config, global_step=0
+         ) as memory_profiler:
+             dl = self._build_dataloader(
+                 config,
+                 world_size,
+                 rank,
+             )
+             dl = iter(dl)
+
+             for i in range(0, num_steps):
+                 input_data, labels = next(dl)
+                 print(f"Step {i} image size: {labels.shape}")
+                 if torch_profiler:
+                     torch_profiler.step()
+                 if memory_profiler:
+                     memory_profiler.step()
+
+                 print(len(input_data["clip_tokens"]))
+                 for k, v in input_data.items():
+                     print(f"Step {i} {k} value: {type(v), v.shape}")
+
+                 assert len(input_data) == 2  # (clip_encodings, t5_encodings)
+                 assert labels.shape == (batch_size, 3, 256, 256)
+                 # assert input_data["clip_tokens"].shape[0] == batch_size
+                 # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)
+
+             if torch_profiler:
+                 torch_profiler.step()
+             if memory_profiler:
+                 memory_profiler.step(exit_ctx=True)
+
+     def test_preprocess(self):
+         # TODO
+         pass
+
+     def _build_dataloader(
+         self,
+         job_config,
+         world_size,
+         rank,
+     ):
+
+         return build_flux_dataloader(
+             dp_world_size=world_size,
+             dp_rank=rank,
+             job_config=job_config,
+             tokenizer=None,
+             infinite=False,
+         )
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ import os
+ import time
+ from typing import Callable
+
+ import torch
+ from einops import rearrange
+
+ from PIL import ExifTags, Image
+
+ from torch import Tensor
+
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
+
+ from torchtitan.experiments.flux.model.autoencoder import (
+     AutoEncoder,
+     AutoEncoderParams,
+     load_ae,
+ )
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
+
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
+ from torchtitan.experiments.flux.utils import (
+     create_position_encoding_for_latents,
+     generate_noise_latent,
+     pack_latents,
+     preprocess_flux_data,
+     unpack_latents,
+ )
+
+
+ def time_shift(mu: float, sigma: float, t: Tensor):
+     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+ def get_lin_function(
+     x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+ ) -> Callable[[float], float]:
+     m = (y2 - y1) / (x2 - x1)
+     b = y1 - m * x1
+     return lambda x: m * x + b
+
+
+ def get_schedule(
+     num_steps: int,
+     image_seq_len: int,
+     base_shift: float = 0.5,
+     max_shift: float = 1.15,
+     shift: bool = True,
+ ) -> list[float]:
+     # extra step for zero
+     timesteps = torch.linspace(1, 0, num_steps + 1)
+
+     # shifting the schedule to favor high timesteps for higher signal images
+     if shift:
+         # estimate mu based on linear estimation between two points
+         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+         timesteps = time_shift(mu, 1.0, timesteps)
+
+     return timesteps.tolist()
+
+
+ class TestGenerateImage:
+     def test_generate_image(self):
+         """
+         Run a forward pass of flux model to generate an image.
+         """
+         name = "flux-dev"
+         img_width = 512
+         img_height = 512
+         seed = None
+         prompt = (
+             "a photo of a forest with mist swirling around the tree trunks. The word "
+             '"FLUX" is painted over it in big, red brush strokes with visible texture'
+         )
+         device = "cuda"
+         num_steps = None
+         loop = False
+         guidance = 3.5
+         output_dir = "output"
+         add_sampling_metadata = True
+
+         prompt = prompt.split("|")
+         if len(prompt) == 1:
+             prompt = prompt[0]
+             additional_prompts = None
+         else:
+             additional_prompts = prompt[1:]
+             prompt = prompt[0]
+
+         assert not (
+             (additional_prompts is not None) and loop
+         ), "Do not provide additional prompts and set loop to True"
+
+         torch_device = torch.device(device)
+         if num_steps is None:
+             num_steps = 30
+
+         # allow for packing and conversion to latent space
+         img_height = 16 * (img_height // 16)
+         img_width = 16 * (img_width // 16)
+
+         # init all components
+         model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
+
+         ae = load_ae(
+             ckpt_path="assets/autoencoder/ae.safetensors",
+             autoencoder_params=AutoEncoderParams(),
+             device=torch_device,
+             dtype=torch.bfloat16,
+         )
+         clip_tokenizer = FluxTokenizer(
+             model_path="openai/clip-vit-large-patch14", max_length=77
+         )
+         t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
+         clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
+             torch_device, dtype=torch.bfloat16
+         )
+         t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
+             torch_device, dtype=torch.bfloat16
+         )
+
+         rng = torch.Generator(device="cpu")
+
+         if seed is None:
+             seed = rng.seed()
+         print(f"Generating with seed {seed}:\n{prompt}")
+         t0 = time.perf_counter()
+         output_name = os.path.join(output_dir, f"img_{seed}.jpg")
+
+         # Tokenize the prompt, on CPU
+         clip_tokens = clip_tokenizer.encode(prompt)
+         t5_tokens = t5_tokenizer.encode(prompt)
+
+         batch = preprocess_flux_data(
+             device=torch_device,
+             dtype=torch.bfloat16,
+             autoencoder=None,
+             clip_encoder=clip_encoder,
+             t5_encoder=t5_encoder,
+             batch={
+                 "clip_tokens": clip_tokens,
+                 "t5_tokens": t5_tokens,
+             },
+         )
+
+         img = self._generate_images(
+             device=torch_device,
+             dtype=torch.bfloat16,
+             model=model,
+             decoder=ae,
+             img_width=img_width,
+             img_height=img_height,
+             denoising_steps=num_steps,
+             seed=seed,
+             clip_encodings=batch["clip_encodings"],
+             t5_encodings=batch["t5_encodings"],
+             guidance=guidance,
+         )
+
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+         t1 = time.perf_counter()
+
+         print(f"Done in {t1 - t0:.1f}s.")
+
+         self._save_image(name, output_name, img, add_sampling_metadata, prompt)
+
+     def _generate_images(
+         self,
+         device: torch.device,
+         dtype: torch.dtype,
+         model: FluxModel,
+         decoder: AutoEncoder,
+         # image params:
+         img_width: int,
+         img_height: int,
+         # sampling params:
+         denoising_steps: int,
+         seed: int,
+         clip_encodings: torch.Tensor,
+         t5_encodings: torch.Tensor,
+         guidance: float = 4.0,
+     ):
+
+         bsz = clip_encodings.shape[0]
+         latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
+         _, latent_channels, latent_height, latent_width = latents.shape
+
+         # create denoising schedule
+         timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
+
+         # create positional encodings
+         POSITION_DIM = 3  # constant for Flux flow model
+         latent_pos_enc = create_position_encoding_for_latents(
+             bsz, latent_height, latent_width, POSITION_DIM
+         ).to(latents)
+         text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
+
+         # convert img-like latents into sequences of patches
+         latents = pack_latents(latents)
+
+         # this is ignored for schnell
+         guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
+         for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+             t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
+             pred = model(
+                 img=latents,
+                 img_ids=latent_pos_enc,
+                 txt=t5_encodings,
+                 txt_ids=text_pos_enc,
+                 y=clip_encodings,
+                 timesteps=t_vec,
+                 guidance=guidance_vec,
+             )
+
+             latents = latents + (t_prev - t_curr) * pred
+
+         # convert sequences of patches into img-like latents
+         latents = unpack_latents(latents, latent_height, latent_width)
+
+         img = decoder.decode(latents)
+         return img
+
+     def _save_image(
+         self,
+         name: str,
+         output_name: str,
+         x: torch.Tensor,
+         add_sampling_metadata: bool,
+         prompt: str,
+     ):
+         print(f"Saving {output_name}")
+         # bring into PIL format and save
+         x = x.clamp(-1, 1)
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+
+         exif_data = Image.Exif()
+         exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+         exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+         exif_data[ExifTags.Base.Model] = name
+         if add_sampling_metadata:
+             exif_data[ExifTags.Base.ImageDescription] = prompt
+         img.save(output_name, exif=exif_data, quality=95, subsampling=0)
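For intuition on the shifted schedule used above, each uniform timestep t is mapped through time_shift(mu, 1.0, t) = exp(mu) / (exp(mu) + (1/t - 1)), with mu estimated linearly from the image sequence length. A small numeric mirror of those two helpers (the sequence length is an illustrative value):

import math

def time_shift(mu: float, sigma: float, t: float) -> float:
    # Scalar mirror of the tensor version defined in the test above.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_lin_function(x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

image_seq_len = 1024                    # illustrative sequence length
mu = get_lin_function()(image_seq_len)  # ~0.63 for the default endpoints
for t in (0.75, 0.5, 0.25):
    # For mu > 0 the shifted timestep stays above t, spending more steps at high noise.
    print(t, round(time_shift(mu, 1.0, t), 3))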
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
+
+ [job]
+ dump_folder = "./outputs"
+ description = "Flux debug model"
+ print_args = false
+ use_for_integration_test = true
+
+ [profiling]
+ enable_profiling = false
+ save_traces_folder = "profile_trace"
+ profile_freq = 10
+ enable_memory_snapshot = false
+ save_memory_snapshot_folder = "memory_snapshot"
+
+ [metrics]
+ log_freq = 1
+ disable_color_printing = false
+ enable_tensorboard = false
+ save_tb_folder = "tb"
+ enable_wandb = false
+
+ [model]
+ name = "flux"
+ flavor = "flux-debug"
+ norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
+ # test tokenizer.model, for debug purpose only
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
+ # converters = "float8"
+
+
+ [optimizer]
+ name = "AdamW"
+ lr = 8e-4
+ eps = 1e-8
+
+ [lr_scheduler]
+ warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+ decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
+ decay_type = "linear"
+ lr_min = 0.0
+
+ [training]
+ batch_size = 32
+ seq_len = 512
+ max_norm = 1.0  # grad norm clipping
+ steps = 10
+ compile = false
+ dataset = "cc12m"
+ guidance = 3.5
+ seed = 0
+
+ [encoder]
+ t5_encoder = "google/t5-v1_1-small"
+ clip_encoder = "openai/clip-vit-large-patch14"
+ max_t5_encoding_len = 512
+ auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image
+
+ [parallelism]
+ data_parallel_replicate_degree = 1
+ data_parallel_shard_degree = 1
+ fsdp_reshard_after_forward = "default"  # default / never / always
+ tensor_parallel_degree = 1
+ enable_async_tensor_parallel = false
+ pipeline_parallel_degree = 1
+ context_parallel_degree = 1
+
+ [experimental]
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation
11
+
12
+ import argparse
13
+ import logging
14
+ import time
15
+
16
+ # from typing import Dict, List, Optional, Tuple
17
+
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import torch
21
+ import triton
22
+
23
+ # import triton.language as tl
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
+ )
29
+
30
+ # Try to import the optimized implementations
31
+ try:
32
+ from torchao_pr.mg_grouped_gemm import grouped_gemm_forward
33
+
34
+ except ImportError:
35
+ logging.error(
36
+ "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path."
37
+ )
38
+ raise
39
+
40
+
41
+ def compute_reference_forward(x, w, m_sizes):
42
+ """
43
+ Reference PyTorch implementation of M*G grouped GEMM forward pass.
44
+
45
+ Args:
46
+ x (torch.Tensor): Input tensor of shape (M, K)
47
+ w (torch.Tensor): Weight tensor of shape (N, K)
48
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
49
+
50
+ Returns:
51
+ torch.Tensor: Output tensor of shape (M, N)
52
+ """
53
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
54
+
55
+ m_start = 0
56
+ for g in range(len(m_sizes)):
57
+ m_size = m_sizes[g].item()
58
+ if m_size > 0:
59
+ m_end = m_start + m_size
60
+
61
+ # Extract group input
62
+ x_g = x[m_start:m_end]
63
+
64
+ # Compute group output
65
+ y_g = torch.matmul(x_g, w.T)
66
+
67
+ # Store result
68
+ result[m_start:m_end] = y_g
69
+
70
+ # Update start index
71
+ m_start = m_end
72
+
73
+ return result
74
+
75
+
76
+ @triton.testing.perf_report(
77
+ triton.testing.Benchmark(
78
+ x_names=["N"], # We'll vary the output dimension
79
+ x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test
80
+ # x_vals=[8192, 16384],
81
+ line_arg="provider", # We'll compare different providers
82
+ line_vals=["pytorch_reference", "M*G grouped GEMM"],
83
+ line_names=["PyTorch Reference", "M*G grouped Kernel"],
84
+ styles=[("blue", "-"), ("red", "-")],
85
+ ylabel="TFLOPS", # We'll measure TFLOPS
86
+ plot_name="mg_grouped_gemm_comparison",
87
+ args={
88
+ "M": 8192, # Batch dimension, fixed for all tests
89
+ "K": 7168, # Hidden dimension, fixed for all tests
90
+ "G": 8, # Number of groups
91
+ "dtype": torch.float16,
92
+ "device": "cuda",
93
+ },
94
+ )
95
+ )
96
+ def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
97
+ """
98
+ Benchmark the forward pass of the grouped GEMM implementation.
99
+
100
+ Args:
101
+ M (int): Total batch size dimension
102
+ K (int): Hidden dimension
103
+ N (int): Output dimension
104
+ G (int): Number of groups
105
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
106
+ dtype (torch.dtype): Data type to use
107
+ device (str): Device to use
108
+
109
+ Returns:
110
+ float: Performance in TFLOPS
111
+ """
112
+ # Create group sizes for M dimension (balanced across groups)
113
+ base_size = M // G
114
+ remainder = M % G
115
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
116
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
117
+
118
+ print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}")
119
+
120
+ # Create input and weight tensors
121
+ x = torch.randn(M, K, dtype=dtype, device=device)
122
+ w = torch.randn(N, K, dtype=dtype, device=device)
123
+
124
+ # Pre-compute for PyTorch reference to ensure fair comparison
125
+ if provider == "pytorch_reference":
126
+ # Warmup
127
+ torch.cuda.synchronize()
128
+ compute_reference_forward(x, w, m_sizes)
129
+ torch.cuda.synchronize()
130
+
131
+ # Benchmark
132
+ start_time = time.time()
133
+ for _ in range(10): # Average over 10 runs
134
+ compute_reference_forward(x, w, m_sizes)
135
+ torch.cuda.synchronize()
136
+ end_time = time.time()
137
+ else: # Optimized kernel
138
+ # Warmup
139
+ torch.cuda.synchronize()
140
+ grouped_gemm_forward(x, w, m_sizes)
141
+ torch.cuda.synchronize()
142
+
143
+ # Benchmark
144
+ start_time = time.time()
145
+ for _ in range(10): # Average over 10 runs
146
+ grouped_gemm_forward(x, w, m_sizes)
147
+ torch.cuda.synchronize()
148
+ end_time = time.time()
149
+
150
+ # Calculate FLOPs
151
+ # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs)
152
+ flops = 2 * M * N * K
153
+
154
+ # Convert to TFLOPS (tera-FLOPS)
155
+ avg_time = (end_time - start_time) / 10 # Average time per run
156
+ tflops = flops / avg_time / 1e12
157
+
158
+ return tflops
159
+
160
+
161
+ @triton.testing.perf_report(
162
+ triton.testing.Benchmark(
163
+ x_names=["G"], # We'll vary the number of groups
164
+ x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test
165
+ line_arg="provider", # We'll compare different providers
166
+ line_vals=["pytorch_reference", "optimized_kernel"],
167
+ line_names=["PyTorch Reference", "Optimized Kernel"],
168
+ styles=[("blue", "-"), ("red", "-")],
169
+ ylabel="TFLOPS", # We'll measure TFLOPS
170
+ plot_name="mg_grouped_gemm_group_scaling",
171
+ args={
172
+ "M": 8192, # Batch dimension, fixed for all tests
173
+ "K": 4096, # Hidden dimension, fixed for all tests
174
+ "N": 8192, # Output dimension, fixed for all tests
175
+ "dtype": torch.float16,
176
+ "device": "cuda",
177
+ },
178
+ )
179
+ )
180
+ def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"):
181
+ """
182
+ Benchmark how performance scales with number of groups.
183
+
184
+ Args:
185
+ M (int): Total batch size dimension
186
+ K (int): Hidden dimension
187
+ N (int): Output dimension
188
+ G (int): Number of groups
189
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
190
+ dtype (torch.dtype): Data type to use
191
+ device (str): Device to use
192
+
193
+ Returns:
194
+ float: Performance in TFLOPS
195
+ """
196
+ # Create group sizes for M dimension (balanced across groups)
197
+ base_size = M // G
198
+ remainder = M % G
199
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
200
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
201
+
202
+ # Create input and weight tensors
203
+ x = torch.randn(M, K, dtype=dtype, device=device)
204
+ w = torch.randn(N, K, dtype=dtype, device=device)
205
+
206
+ # Benchmark logic - same as previous function
207
+ if provider == "pytorch_reference":
208
+ torch.cuda.synchronize()
209
+ compute_reference_forward(x, w, m_sizes)
210
+ torch.cuda.synchronize()
211
+
212
+ start_time = time.time()
213
+ for _ in range(10):
214
+ compute_reference_forward(x, w, m_sizes)
215
+ torch.cuda.synchronize()
216
+ end_time = time.time()
217
+ else:
218
+ torch.cuda.synchronize()
219
+ grouped_gemm_forward(x, w, m_sizes)
220
+ torch.cuda.synchronize()
221
+
222
+ start_time = time.time()
223
+ for _ in range(10):
224
+ grouped_gemm_forward(x, w, m_sizes)
225
+ torch.cuda.synchronize()
226
+ end_time = time.time()
227
+
228
+ # Calculate FLOPs and TFLOPS
229
+ flops = 2 * M * N * K
230
+ avg_time = (end_time - start_time) / 10
231
+ tflops = flops / avg_time / 1e12
232
+
233
+ return tflops
234
+
235
+
236
+ @triton.testing.perf_report(
237
+ triton.testing.Benchmark(
238
+ x_names=["group_balance"], # We'll vary the group balance factor
239
+ x_vals=[
240
+ 0.0,
241
+ 0.25,
242
+ 0.5,
243
+ 0.75,
244
+ 0.9,
245
+ ], # Different imbalance factors (0 = balanced, 1 = max imbalance)
246
+ line_arg="provider", # We'll compare different providers
247
+ line_vals=["pytorch_reference", "optimized_kernel"],
248
+ line_names=["PyTorch Reference", "Optimized Kernel"],
249
+ styles=[("blue", "-"), ("red", "-")],
250
+ ylabel="TFLOPS", # We'll measure TFLOPS
251
+ plot_name="mg_grouped_gemm_imbalance",
252
+ args={
253
+ "M": 8192, # Batch dimension, fixed for all tests
254
+ "K": 4096, # Hidden dimension, fixed for all tests
255
+ "N": 8192, # Output dimension, fixed for all tests
256
+ "G": 4, # Number of groups
257
+ "dtype": torch.float16,
258
+ "device": "cuda",
259
+ },
260
+ )
261
+ )
262
+ def benchmark_imbalance(
263
+ M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda"
264
+ ):
265
+ """
266
+ Benchmark how performance is affected by imbalanced group sizes.
267
+
268
+ Args:
269
+ M (int): Total batch size dimension
270
+ K (int): Hidden dimension
271
+ N (int): Output dimension
272
+ G (int): Number of groups
273
+ group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance)
274
+ provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel')
275
+ dtype (torch.dtype): Data type to use
276
+ device (str): Device to use
277
+
278
+ Returns:
279
+ float: Performance in TFLOPS
280
+ """
281
+ # Create imbalanced group sizes for M dimension
282
+ if group_balance == 0:
283
+ # Balanced case
284
+ base_size = M // G
285
+ remainder = M % G
286
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
287
+ else:
288
+ # Imbalanced case
289
+ # First group gets more elements, last group gets fewer
290
+ # The imbalance is controlled by the group_balance factor
291
+ remaining = M
292
+ M_sizes = []
293
+ for g in range(G):
294
+ # Interpolate from balanced to imbalanced based on group_balance
295
+ # For balanced (group_balance=0), each group gets M/G
296
+ # For imbalanced (group_balance=1), first group gets much more than last group
297
+ balanced_size = remaining // (G - g)
298
+
299
+ # Adjusting size based on position and imbalance factor
300
+ # First groups get more, last groups get less
301
+ if g < G // 2:
302
+ # First half of groups get more
303
+ adjustment = int(balanced_size * group_balance * (1 - g / (G - 1)))
304
+ size = balanced_size + adjustment
305
+ else:
306
+ # Second half of groups get less
307
+ adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5))
308
+ size = balanced_size - adjustment
309
+
310
+ # Ensure we don't go below 1 or take more than remaining
311
+ size = max(1, min(size, remaining))
312
+ M_sizes.append(size)
313
+ remaining -= size
314
+
315
+ # Handle any remaining elements
316
+ if remaining > 0:
317
+ M_sizes[-1] += remaining
318
+
319
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
320
+
321
+ # Create input and weight tensors
322
+ x = torch.randn(M, K, dtype=dtype, device=device)
323
+ w = torch.randn(N, K, dtype=dtype, device=device)
324
+
325
+ # Benchmark logic
326
+ if provider == "pytorch_reference":
327
+ torch.cuda.synchronize()
328
+ compute_reference_forward(x, w, m_sizes)
329
+ torch.cuda.synchronize()
330
+
331
+ start_time = time.time()
332
+ for _ in range(10):
333
+ compute_reference_forward(x, w, m_sizes)
334
+ torch.cuda.synchronize()
335
+ end_time = time.time()
336
+ else:
337
+ torch.cuda.synchronize()
338
+ grouped_gemm_forward(x, w, m_sizes)
339
+ torch.cuda.synchronize()
340
+
341
+ start_time = time.time()
342
+ for _ in range(10):
343
+ grouped_gemm_forward(x, w, m_sizes)
344
+ torch.cuda.synchronize()
345
+ end_time = time.time()
346
+
347
+ # Calculate FLOPs and TFLOPS
348
+ flops = 2 * M * N * K
349
+ avg_time = (end_time - start_time) / 10
350
+ tflops = flops / avg_time / 1e12
351
+
352
+ return tflops
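For intuition about what the imbalance factor does, the standalone sketch below mirrors the group-splitting logic above (pure Python, no GPU required); the helper name is illustrative only and is not part of the benchmark file.

import math

def imbalanced_group_sizes(M, G, group_balance):
    # Mirror of the imbalanced branch above: early groups get more rows, later groups fewer.
    remaining, sizes = M, []
    for g in range(G):
        balanced = remaining // (G - g)
        if g < G // 2:
            size = balanced + int(balanced * group_balance * (1 - g / (G - 1)))
        else:
            size = balanced - int(balanced * group_balance * ((g / (G - 1)) - 0.5))
        size = max(1, min(size, remaining))
        sizes.append(size)
        remaining -= size
    if remaining > 0:
        sizes[-1] += remaining  # hand any leftover rows to the last group
    return sizes

print(imbalanced_group_sizes(8192, 4, 0.0))  # balanced: [2048, 2048, 2048, 2048]
print(imbalanced_group_sizes(8192, 4, 0.5))  # front-loaded: the first groups receive more rows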
353
+
354
+
355
+ def benchmark_model_configs():
356
+ """
357
+ Benchmark common model configurations used in DeepSeek-like models.
358
+ """
359
+ # Model configurations: (M, K, N, G)
360
+ configs = [
361
+ (8192, 7168, 4096, 4), # Config 1
362
+ (8192, 2048, 7168, 4), # Config 2
363
+ (4096, 7168, 4096, 8), # Config 3
364
+ (4096, 2048, 7168, 8), # Config 4
365
+ ]
366
+
367
+ results = []
368
+
369
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
370
+ dtype = torch.float16
371
+
372
+ for config_idx, (M, K, N, G) in enumerate(configs):
373
+ logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====")
374
+ logging.info(f"M={M}, K={K}, N={N}, G={G}")
375
+
376
+ # Create group sizes for M dimension
377
+ base_size = M // G
378
+ remainder = M % G
379
+ M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
380
+ m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)
381
+
382
+ # Create tensors
383
+ x = torch.randn(M, K, dtype=dtype, device=device)
384
+ w = torch.randn(N, K, dtype=dtype, device=device)
385
+
386
+ # Benchmark PyTorch reference
387
+ torch.cuda.synchronize()
388
+ compute_reference_forward(x, w, m_sizes) # Warmup
389
+ torch.cuda.synchronize()
390
+
391
+ logging.info("Benchmarking PyTorch reference...")
392
+ torch.cuda.reset_peak_memory_stats()
393
+ start_time = time.time()
394
+ for _ in range(10):
395
+ compute_reference_forward(x, w, m_sizes)
396
+ torch.cuda.synchronize()
397
+ end_time = time.time()
398
+ pt_time = (end_time - start_time) / 10
399
+ pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
400
+
401
+ # Benchmark optimized kernel
402
+ torch.cuda.synchronize()
403
+ grouped_gemm_forward(x, w, m_sizes) # Warmup
404
+ torch.cuda.synchronize()
405
+
406
+ logging.info("Benchmarking optimized kernel...")
407
+ torch.cuda.reset_peak_memory_stats()
408
+ start_time = time.time()
409
+ for _ in range(10):
410
+ grouped_gemm_forward(x, w, m_sizes)
411
+ torch.cuda.synchronize()
412
+ end_time = time.time()
413
+ opt_time = (end_time - start_time) / 10
414
+ opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB
415
+
416
+ # Calculate FLOPs and speedup
417
+ flops = 2 * M * N * K
418
+ pt_tflops = flops / pt_time / 1e12
419
+ opt_tflops = flops / opt_time / 1e12
420
+ speedup = pt_time / opt_time
421
+
422
+ # Store results
423
+ results.append(
424
+ {
425
+ "config": f"Config {config_idx + 1}",
426
+ "dimensions": f"M={M}, K={K}, N={N}, G={G}",
427
+ "pt_time_ms": pt_time * 1000,
428
+ "opt_time_ms": opt_time * 1000,
429
+ "pt_tflops": pt_tflops,
430
+ "opt_tflops": opt_tflops,
431
+ "speedup": speedup,
432
+ "pt_memory_mb": pt_memory,
433
+ "opt_memory_mb": opt_memory,
434
+ "memory_savings": (
435
+ (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0
436
+ ),
437
+ }
438
+ )
439
+
440
+ logging.info(
441
+ f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB"
442
+ )
443
+ logging.info(
444
+ f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB"
445
+ )
446
+ logging.info(
447
+ f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%"
448
+ )
449
+
450
+ # Print summary table
451
+ logging.info("\n===== Benchmark Results Summary =====")
452
+ logging.info(
453
+ f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}"
454
+ )
455
+ logging.info(
456
+ f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | "
457
+ f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}"
458
+ )
459
+ logging.info("-" * 100)
460
+
461
+ for result in results:
462
+ logging.info(
463
+ f"{result['config']:<10} | "
464
+ f"{result['pt_time_ms']:<9.2f} {result['opt_time_ms']:<9.2f} | "
465
+ f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | "
466
+ f"{result['speedup']:<10.2f} | "
467
+ f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | "
468
+ f"{result['memory_savings']:<12.2f}%"
469
+ )
470
+
471
+ return results
472
+
473
+
474
+ def plot_benchmark_results(results):
475
+ """
476
+ Plot benchmark results as bar charts.
477
+ """
478
+ # Extract data
479
+ configs = [r["config"] for r in results]
480
+ pt_tflops = [r["pt_tflops"] for r in results]
481
+ opt_tflops = [r["opt_tflops"] for r in results]
482
+ speedups = [r["speedup"] for r in results]
483
+
484
+ # Create figure with subplots
485
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
486
+
487
+ # Plot TFLOPS comparison
488
+ x = np.arange(len(configs))
489
+ width = 0.35
490
+ ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference")
491
+ ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel")
492
+ ax1.set_xlabel("Model Configuration")
493
+ ax1.set_ylabel("TFLOPS")
494
+ ax1.set_title("Performance Comparison (Higher is Better)")
495
+ ax1.set_xticks(x)
496
+ ax1.set_xticklabels(configs)
497
+ ax1.legend()
498
+ ax1.grid(axis="y", linestyle="--", alpha=0.7)
499
+
500
+ # Plot speedup
501
+ ax2.bar(x, speedups, width=0.6, color="green")
502
+ ax2.set_xlabel("Model Configuration")
503
+ ax2.set_ylabel("Speedup (x)")
504
+ ax2.set_title("Speedup Factor (Higher is Better)")
505
+ ax2.set_xticks(x)
506
+ ax2.set_xticklabels(configs)
507
+ ax2.grid(axis="y", linestyle="--", alpha=0.7)
508
+
509
+ # Add speedup values on top of bars
510
+ for i, v in enumerate(speedups):
511
+ ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center")
512
+
513
+ plt.tight_layout()
514
+ plt.savefig("mg_grouped_gemm_benchmark_results.png")
515
+ logging.info(
516
+ "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'"
517
+ )
518
+
519
+
520
+ def compare_mg_implementations():
521
+ """
522
+ Combine the M*G and N*G benchmark results for comparison.
523
+ """
524
+ # Only run this if both NG and MG benchmarks have been run
525
+ try:
526
+ import pandas as pd
527
+
528
+ # Try to load previous benchmark results
529
+ mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv")
530
+ ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv")
531
+
532
+ # Create comparison plot
533
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
534
+
535
+ # Plot speedup comparison
536
+ configs = mg_results["config"].unique()
537
+ mg_speedups = mg_results.groupby("config")["speedup"].mean()
538
+ ng_speedups = ng_results.groupby("config")["speedup"].mean()
539
+
540
+ x = np.arange(len(configs))
541
+ width = 0.35
542
+
543
+ axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping")
544
+ axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping")
545
+ axes[0].set_xlabel("Model Configuration")
546
+ axes[0].set_ylabel("Speedup (x)")
547
+ axes[0].set_title("Speedup Comparison: M*G vs N*G")
548
+ axes[0].set_xticks(x)
549
+ axes[0].set_xticklabels(configs)
550
+ axes[0].legend()
551
+ axes[0].grid(axis="y", linestyle="--", alpha=0.7)
552
+
553
+ # Plot TFLOPS comparison for optimized kernels
554
+ mg_tflops = (
555
+ mg_results[mg_results["implementation"] == "optimized"]
556
+ .groupby("config")["tflops"]
557
+ .mean()
558
+ )
559
+ ng_tflops = (
560
+ ng_results[ng_results["implementation"] == "optimized"]
561
+ .groupby("config")["tflops"]
562
+ .mean()
563
+ )
564
+
565
+ axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping")
566
+ axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping")
567
+ axes[1].set_xlabel("Model Configuration")
568
+ axes[1].set_ylabel("TFLOPS")
569
+ axes[1].set_title("Performance Comparison: M*G vs N*G")
570
+ axes[1].set_xticks(x)
571
+ axes[1].set_xticklabels(configs)
572
+ axes[1].legend()
573
+ axes[1].grid(axis="y", linestyle="--", alpha=0.7)
574
+
575
+ plt.tight_layout()
576
+ plt.savefig("mg_vs_ng_comparison.png")
577
+ logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'")
578
+
579
+ except Exception as e:
580
+ logging.error(f"Could not create comparison plot: {e}")
581
+ logging.info(
582
+ "Run both M*G and N*G benchmarks first to generate comparison plots"
583
+ )
584
+
585
+
586
+ if __name__ == "__main__":
587
+ parser = argparse.ArgumentParser(
588
+ description="Benchmark M*G Grouped GEMM implementations"
589
+ )
590
+ parser.add_argument("--run-all", action="store_true", help="Run all benchmarks")
591
+ parser.add_argument(
592
+ "--triton-bench", action="store_true", help="Run Triton performance reports"
593
+ )
594
+ parser.add_argument(
595
+ "--model-configs", action="store_true", help="Benchmark model configurations"
596
+ )
597
+ parser.add_argument(
598
+ "--compare-mg-ng",
599
+ action="store_true",
600
+ help="Compare M*G and N*G implementations",
601
+ )
602
+ args = parser.parse_args()
603
+
604
+ # Check if CUDA is available
605
+ if not torch.cuda.is_available():
606
+ logging.error(
607
+ "CUDA is not available. This benchmark requires a CUDA-capable GPU."
608
+ )
609
+ exit(1)
610
+
611
+ if args.run_all or args.model_configs:
612
+ # Benchmark model configurations
613
+ logging.info("Running benchmark for model configurations...")
614
+ results = benchmark_model_configs()
615
+ plot_benchmark_results(results)
616
+
617
+ if args.run_all or args.triton_bench:
618
+ # Run Triton performance reports
619
+ logging.info("Running Triton performance reports...")
620
+ benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results")
621
+ benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results")
622
+ benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")
623
+ logging.info(
624
+ "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory"
625
+ )
626
+
627
+ if args.run_all or args.compare_mg_ng:
628
+ # Compare M*G and N*G implementations
629
+ logging.info("Comparing M*G and N*G implementations...")
630
+ compare_mg_implementations()
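The benchmarks above can also be driven programmatically rather than through the CLI flags. The following is a minimal hedged sketch, assuming a CUDA-capable GPU and that this file is importable as a module named benchmark (the module name is an assumption, not part of the diff).

import torch
import benchmark  # assumed module name for the file added above

if torch.cuda.is_available():
    # Timed PyTorch-reference vs. Triton-kernel runs over the model configs.
    results = benchmark.benchmark_model_configs()
    benchmark.plot_benchmark_results(results)  # writes mg_grouped_gemm_benchmark_results.png
    # Triton perf report (TFLOPS curves), saved under the given directory.
    benchmark.benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results")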
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
45
+ @triton.autotune(
46
+ configs=_NV_CONFIGS,
47
+ key=["G", "M_BUCKET", "N", "K"],
48
+ prune_configs_by={"early_config_prune": early_config_prune},
49
+ )
50
+ @triton.jit
51
+ def _kernel_mg_forward_hopper(
52
+ a_desc_ptr,
53
+ b_desc_ptr,
54
+ c_ptr,
55
+ workspace,
56
+ m_sizes,
57
+ # problem sizes
58
+ G: tl.constexpr,
59
+ M_BUCKET: tl.constexpr,
60
+ N: tl.constexpr,
61
+ K: tl.constexpr,
62
+ # config
63
+ NUM_SMS: tl.constexpr,
64
+ TMA_SIZE: tl.constexpr,
65
+ USE_EPILOGUE_SUBTILING: tl.constexpr,
66
+ # tiles
67
+ BLOCK_SIZE_M: tl.constexpr,
68
+ BLOCK_SIZE_N: tl.constexpr,
69
+ BLOCK_SIZE_K: tl.constexpr,
70
+ ) -> None:
71
+ """
72
+ Flat index style forward kernel for Hopper.
73
+ For simplicity, we always use TMA Load and TMA Store
74
+ """
75
+ tbidx = tl.program_id(0) # thread block index
76
+
77
+ c_dtype = c_ptr.dtype.element_ty # output dtype
78
+
79
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store
80
+
81
+ M_end = 0
82
+ M_start = 0
83
+ processed_tiles = 0
84
+ # Size of individual weight matrix
85
+ n_size = N // G
86
+ n_start = 0
87
+
88
+ for g in range(G):
89
+ # Move down along groups
90
+ # reset to new M offset
91
+ M_start = M_end
92
+ m_size = tl.load(m_sizes + g)
93
+ M_end = M_start + m_size
94
+ n_start = n_size * g
95
+
96
+ if m_size > 0:
97
+ # Process this group
98
+
99
+ # Acquire hold on c_desc_ptr for TMA Store
100
+ tl.extra.cuda.experimental_device_tensormap_create2d(
101
+ desc_ptr=c_desc_ptr,
102
+ global_address=c_ptr + M_start * n_size,
103
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
104
+ global_size=[m_size, n_size],
105
+ element_ty=c_dtype,
106
+ )
107
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
108
+
109
+ # tiles for this group
110
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
111
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
112
+ group_num_tiles = num_m_tiles * num_n_tiles
113
+
114
+ while tbidx >= processed_tiles and tbidx < (
115
+ processed_tiles + group_num_tiles
116
+ ):
117
+ group_index = tbidx - processed_tiles
118
+
119
+ # columnwise
120
+ tile_m_index = group_index % num_m_tiles
121
+ tile_n_index = group_index // num_m_tiles
122
+
123
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
124
+
125
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
126
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
127
+ global_n_offset = (n_start + n_offset).to(tl.int32)
128
+
129
+ for k_offset in range(0, K, BLOCK_SIZE_K):
130
+ # input block [M,K]
131
+ a = tl._experimental_descriptor_load(
132
+ a_desc_ptr,
133
+ [m_offset, k_offset],
134
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
135
+ c_dtype,
136
+ )
137
+ # weight block [N, K]
138
+ b = tl._experimental_descriptor_load(
139
+ b_desc_ptr,
140
+ [global_n_offset, k_offset],
141
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
142
+ c_dtype,
143
+ )
144
+
145
+ accumulator += tl.dot(a, b.T)
146
+
147
+ # Store using TMA
148
+
149
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
150
+
151
+ if USE_EPILOGUE_SUBTILING:
152
+ acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
153
+ acc = tl.permute(acc, (0, 2, 1))
154
+ acc0, acc1 = tl.split(acc)
155
+ c0 = acc0.to(c_dtype)
156
+ tl._experimental_descriptor_store(
157
+ c_desc_ptr, c0, [m_offset, n_offset]
158
+ )
159
+ c1 = acc1.to(c_dtype)
160
+ tl._experimental_descriptor_store(
161
+ c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
162
+ )
163
+ else:
164
+ tl._experimental_descriptor_store(
165
+ c_desc_ptr,
166
+ accumulator.to(c_dtype),
167
+ [m_offset, n_offset],
168
+ )
169
+ # move to next tile in group
170
+ tbidx += NUM_SMS
171
+ # Update the total tiles count for the next group
172
+ processed_tiles += group_num_tiles
173
+
174
+
175
+ @triton.autotune(
176
+ configs=_NV_CONFIGS,
177
+ key=["G", "M_BUCKET", "N", "K"],
178
+ prune_configs_by={"early_config_prune": early_config_prune},
179
+ )
180
+ @triton.jit
181
+ def _kernel_mg_forward_tma(
182
+ a_desc_ptr,
183
+ b_desc_ptr,
184
+ c_ptr,
185
+ workspace,
186
+ m_sizes,
187
+ a_scale_ptr,
188
+ b_scale_ptr,
189
+ # problem sizes
190
+ G: tl.constexpr,
191
+ M_BUCKET: tl.constexpr,
192
+ N: tl.constexpr,
193
+ K: tl.constexpr,
194
+ # config
195
+ NUM_SMS: tl.constexpr,
196
+ USE_TMA_LOAD: tl.constexpr,
197
+ USE_TMA_STORE: tl.constexpr,
198
+ TMA_SIZE: tl.constexpr,
199
+ USE_FP8: tl.constexpr,
200
+ # tiles
201
+ BLOCK_SIZE_M: tl.constexpr,
202
+ BLOCK_SIZE_N: tl.constexpr,
203
+ BLOCK_SIZE_K: tl.constexpr,
204
+ ) -> None:
205
+ """
206
+ Flat index style forward kernel.
207
+ For simplicity, we always use TMA Load and TMA Store
208
+ """
209
+ tbidx = tl.program_id(0) # thread block index
210
+
211
+ c_dtype = c_ptr.dtype.element_ty
212
+
213
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
214
+
215
+ M_end = 0
216
+ processed_tiles = 0
217
+
218
+ for g in range(G):
219
+ # Move down along groups
220
+ # reset to new M offset
221
+ M_start = M_end
222
+ m_size = tl.load(m_sizes + g)
223
+ M_end = M_start + m_size
224
+
225
+ if m_size > 0:
226
+ # Process this group
227
+ n_size = N
228
+
229
+ # TMA Store prep
230
+ tl.extra.cuda.experimental_device_tensormap_create2d(
231
+ desc_ptr=c_desc_ptr,
232
+ global_address=c_ptr + M_start * N,
233
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
234
+ global_size=[m_size, n_size],
235
+ element_ty=c_dtype,
236
+ )
237
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
238
+
239
+ # tiles for this group
240
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
241
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
242
+ group_num_tiles = num_m_tiles * num_n_tiles
243
+
244
+ while tbidx >= processed_tiles and tbidx < (
245
+ processed_tiles + group_num_tiles
246
+ ):
247
+ group_index = tbidx - processed_tiles
248
+
249
+ tile_m_index = group_index % num_m_tiles
250
+ tile_n_index = group_index // num_m_tiles
251
+
252
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
253
+
254
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
255
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
256
+
257
+ for k_offset in range(0, K, BLOCK_SIZE_K):
258
+ # input block [M,K]
259
+ a = tl._experimental_descriptor_load(
260
+ a_desc_ptr,
261
+ [m_offset, k_offset],
262
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
263
+ c_dtype,
264
+ )
265
+ # weight block [N, K]
266
+ b = tl._experimental_descriptor_load(
267
+ b_desc_ptr,
268
+ [n_offset, k_offset],
269
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
270
+ c_dtype,
271
+ )
272
+
273
+ accumulator += tl.dot(a, b.T)
274
+
275
+ # Store using TMA
276
+
277
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
278
+ # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
279
+
280
+ tl._experimental_descriptor_store(
281
+ c_desc_ptr,
282
+ accumulator.to(c_dtype),
283
+ [m_offset, n_offset],
284
+ )
285
+
286
+ # Move to the next tile
287
+ tbidx += NUM_SMS
288
+ # Update the total tiles count for the next group
289
+ processed_tiles += group_num_tiles
290
+
291
+
292
+ @triton.autotune(
293
+ configs=_NV_CONFIGS,
294
+ key=["G", "M_BUCKET", "N", "K"],
295
+ prune_configs_by={"early_config_prune": early_config_prune},
296
+ )
297
+ @triton.jit
298
+ def _kernel_mg_forward_no_tma(
299
+ a_ptr,
300
+ b_ptr,
301
+ c_ptr,
302
+ workspace,
303
+ m_sizes,
304
+ # problem sizes
305
+ G: tl.constexpr,
306
+ M_BUCKET: tl.constexpr,
307
+ N: tl.constexpr,
308
+ K: tl.constexpr,
309
+ # config
310
+ NUM_SMS: tl.constexpr,
311
+ USE_TMA_LOAD: tl.constexpr,
312
+ USE_TMA_STORE: tl.constexpr,
313
+ TMA_SIZE: tl.constexpr,
314
+ # tiles
315
+ BLOCK_SIZE_M: tl.constexpr,
316
+ BLOCK_SIZE_N: tl.constexpr,
317
+ BLOCK_SIZE_K: tl.constexpr,
318
+ ) -> None:
319
+ """
320
+ Flat index style forward kernel.
321
+ For pre-Hopper devices (e.g., Ampere), we never use TMA Load and TMA Store
322
+ """
323
+ tbidx = tl.program_id(0) # thread block index
324
+
325
+ c_dtype = c_ptr.dtype.element_ty
326
+ c_desc_ptr = None
327
+
328
+ M_end = 0
329
+ processed_tiles = 0
330
+
331
+ for g in range(G):
332
+ # Move down along groups
333
+ # reset to new M offset
334
+ M_start = M_end
335
+ m_size = tl.load(m_sizes + g)
336
+ M_end = M_start + m_size
337
+
338
+ if m_size > 0:
339
+ # Process this group
340
+ n_size = N
341
+
342
+ # tiles for this group
343
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
344
+ num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
345
+ group_num_tiles = num_m_tiles * num_n_tiles
346
+
347
+ while tbidx >= processed_tiles and tbidx < (
348
+ processed_tiles + group_num_tiles
349
+ ):
350
+ group_index = tbidx - processed_tiles
351
+
352
+ tile_m_index = group_index % num_m_tiles
353
+ tile_n_index = group_index // num_m_tiles
354
+
355
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
356
+
357
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
358
+ n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
359
+
360
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
361
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
362
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
363
+
364
+ a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
365
+ b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]
366
+
367
+ for k_offset in range(0, K, BLOCK_SIZE_K):
368
+ # Load with bounds checking
369
+ a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
370
+ b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
371
+
372
+ # Main matmul
373
+ accumulator += tl.dot(a, b.T)
374
+
375
+ # Update pointers for next block
376
+ a_ptrs += BLOCK_SIZE_K
377
+ b_ptrs += BLOCK_SIZE_K
378
+
379
+ # Store without TMA
380
+ offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
381
+ offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
382
+
383
+ c = accumulator.to(c_dtype)
384
+
385
+ tl.store(
386
+ c_ptr
387
+ + (M_start + offs_am[:, None]) * N # Row stride is N
388
+ + offs_bn[None, :], # Column offset
389
+ c,
390
+ mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
391
+ )
392
+ # Move to the next tile
393
+ tbidx += NUM_SMS
394
+ # Update the total tiles count for the next group
395
+ processed_tiles += group_num_tiles
396
+
397
+
398
+ """
399
+ Backward pass for grouped GEMM with Triton, where grouping is M*G
400
+ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
401
+ """
402
+
403
+
404
+ # ---- dx flat linear indexed ----
405
+ @triton.autotune(
406
+ configs=_NV_CONFIGS,
407
+ key=["G", "M_BUCKET", "N", "K"],
408
+ prune_configs_by={"early_config_prune": early_config_prune},
409
+ )
410
+ @triton.jit
411
+ def _kernel_mg_dx_tma(
412
+ grad_output_desc_ptr, # [MG, N]
413
+ w_desc_ptr, # [N, K]
414
+ grad_input_ptr, # output grad_x [MG, K]
415
+ workspace, # for TMA store
416
+ m_sizes, # group sizes [G]
417
+ # problem sizes
418
+ G: tl.constexpr,
419
+ M_BUCKET: tl.constexpr,
420
+ N: tl.constexpr,
421
+ K: tl.constexpr,
422
+ # config
423
+ NUM_SMS: tl.constexpr,
424
+ USE_TMA_LOAD: tl.constexpr,
425
+ USE_TMA_STORE: tl.constexpr,
426
+ TMA_SIZE: tl.constexpr,
427
+ # tiles
428
+ BLOCK_SIZE_M: tl.constexpr,
429
+ BLOCK_SIZE_N: tl.constexpr,
430
+ BLOCK_SIZE_K: tl.constexpr,
431
+ ) -> None:
432
+ """
433
+ TMA-optimized kernel for computing gradients with respect to input (dx).
434
+ For the forward pass Y = X @ W.T, the backward for input is:
435
+ grad_X = grad_Y @ W
436
+
437
+ This maps to [MG, N] @ [N, K] -> [MG, K]
438
+
439
+ Key differences from forward:
440
+ 1. W is used directly and not transposed
441
+ 2. The reduction dimension is now N (not K)
442
+ 3. Output is [M, K] instead of [M, N]
443
+ """
444
+ tbidx = tl.program_id(0) # thread block index
445
+
446
+ c_dtype = grad_input_ptr.dtype.element_ty
447
+ c_desc_ptr = workspace + (tbidx * TMA_SIZE)
448
+
449
+ M_end = 0
450
+ processed_tiles = 0
451
+
452
+ for g in range(G):
453
+ # Move down along groups - same as forward
454
+ M_start = M_end
455
+ m_size = tl.load(m_sizes + g)
456
+ M_end = M_start + m_size
457
+
458
+ if m_size > 0:
459
+ # Process this group
460
+ # tiles for this group - now producing [M, K] output
461
+ num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
462
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
463
+ group_num_tiles = num_m_tiles * num_k_tiles
464
+
465
+ # TMA Store prep for [M, K] output
466
+ tl.extra.cuda.experimental_device_tensormap_create2d(
467
+ desc_ptr=c_desc_ptr,
468
+ global_address=grad_input_ptr + M_start * K,
469
+ load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
470
+ global_size=[m_size, K],
471
+ element_ty=c_dtype,
472
+ )
473
+ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
474
+
475
+ while tbidx >= processed_tiles and tbidx < (
476
+ processed_tiles + group_num_tiles
477
+ ):
478
+ group_index = tbidx - processed_tiles
479
+
480
+ # Different tiling scheme for [M, K] output
481
+ tile_m_index = group_index % num_m_tiles
482
+ tile_k_index = group_index // num_m_tiles
483
+
484
+ # for grad_input block [M, K]
485
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
486
+
487
+ # Position in full matrix
488
+ m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
489
+ k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
490
+
491
+ # reduce along N dimension (instead of K in forward)
492
+ for n_offset in range(0, N, BLOCK_SIZE_N):
493
+ # grad_output block [M, N]
494
+ grad_output = tl._experimental_descriptor_load(
495
+ grad_output_desc_ptr,
496
+ [m_offset, n_offset],
497
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
498
+ c_dtype,
499
+ )
500
+
501
+ # weight block [N, K] - no transpose needed
502
+ w = tl._experimental_descriptor_load(
503
+ w_desc_ptr,
504
+ [n_offset, k_offset],
505
+ [BLOCK_SIZE_N, BLOCK_SIZE_K],
506
+ c_dtype,
507
+ )
508
+
509
+ # grad_x = grad_output @ w
510
+ # reducing along N dimension
511
+ accumulator += tl.dot(grad_output, w)
512
+
513
+ # Store using TMA
514
+ m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
515
+ # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)
516
+
517
+ tl._experimental_descriptor_store(
518
+ c_desc_ptr,
519
+ accumulator.to(c_dtype),
520
+ [m_offset, k_offset],
521
+ )
522
+
523
+ # Move to the next tile
524
+ tbidx += NUM_SMS
525
+
526
+ # Update the total tiles count for the next group
527
+ processed_tiles += group_num_tiles
528
+
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
533
+ @triton.autotune(
534
+ configs=_NV_CONFIGS,
535
+ key=["G", "M_BUCKET", "N", "K"],
536
+ prune_configs_by={"early_config_prune": early_config_prune},
537
+ )
538
+ @triton.jit
539
+ def _kernel_mg_dw_tma(
540
+ x_desc_ptr, # input descriptor [M_total, K]
541
+ grad_output_desc_ptr, # grad_output descriptor [M_total, N]
542
+ grad_weight_ptr, # output grad_w [N, K]
543
+ workspace, # workspace for TMA store
544
+ m_sizes, # group sizes [G]
545
+ # problem sizes
546
+ G: tl.constexpr,
547
+ M_BUCKET: tl.constexpr,
548
+ N: tl.constexpr,
549
+ K: tl.constexpr,
550
+ # config
551
+ NUM_SMS: tl.constexpr,
552
+ USE_TMA_LOAD: tl.constexpr,
553
+ USE_TMA_STORE: tl.constexpr,
554
+ TMA_SIZE: tl.constexpr,
555
+ # tiles
556
+ BLOCK_SIZE_N: tl.constexpr,
557
+ BLOCK_SIZE_K: tl.constexpr,
558
+ BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension
559
+ ) -> None:
560
+ """
561
+ Improved TMA-optimized kernel for computing gradients with respect to weights (dw).
562
+ Uses flat index structure similar to forward.
563
+
564
+ For the forward pass Y = X @ W.T,
565
+ the backward for weights is:
566
+ grad_W = grad_Y.T @ X
567
+
568
+ Where:
569
+ - grad_Y is [MG, N]
570
+ - X is [MG, K]
571
+ - grad_W is [N, K]
572
+ - we return [N,K]
573
+ """
574
+ # Get thread block index
575
+ tbidx = tl.program_id(0)
576
+
577
+ # Get output data type
578
+ c_dtype = grad_weight_ptr.dtype.element_ty
579
+
580
+ # Calculate number of output tiles
581
+ num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
582
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
583
+ total_output_tiles = num_n_tiles * num_k_tiles
584
+
585
+ # Process tiles in strided manner across SMs
586
+ for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
587
+ # Calculate tile indices
588
+ tile_n_idx = tile_idx % num_n_tiles
589
+ tile_k_idx = tile_idx // num_n_tiles
590
+
591
+ # Calculate global offsets
592
+ n_offset = tile_n_idx * BLOCK_SIZE_N
593
+ k_offset = tile_k_idx * BLOCK_SIZE_K
594
+
595
+ # Initialize accumulator for this output tile [N, K]
596
+ accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
597
+
598
+ # Process each group
599
+ M_end = 0
600
+ for g in range(G):
601
+ # Get group boundaries
602
+ M_start = M_end
603
+ m_size = tl.load(m_sizes + g)
604
+ M_end = M_start + m_size
605
+
606
+ # Only process if group is non-empty
607
+ if m_size > 0:
608
+ # Process this group in chunks along the M dimension
609
+ for m_offset in range(0, m_size, BLOCK_SIZE_M):
610
+ # Calculate actual block size (handling boundary)
611
+ m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)
612
+
613
+ # Only process if we have actual work to do
614
+ if m_block_size > 0:
615
+ # Global offset for this chunk
616
+ m_global_offset = M_start + m_offset
617
+
618
+ if USE_TMA_LOAD:
619
+ # Load input chunk [M_chunk, K] using TMA
620
+ x_block = tl._experimental_descriptor_load(
621
+ x_desc_ptr,
622
+ [m_global_offset, k_offset],
623
+ [BLOCK_SIZE_M, BLOCK_SIZE_K],
624
+ c_dtype,
625
+ )
626
+
627
+ # Load grad_output chunk [M_chunk, N] using TMA
628
+ grad_output_block = tl._experimental_descriptor_load(
629
+ grad_output_desc_ptr,
630
+ [m_global_offset, n_offset],
631
+ [BLOCK_SIZE_M, BLOCK_SIZE_N],
632
+ c_dtype,
633
+ )
634
+
635
+ # Apply masks for valid regions
636
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
637
+ m_mask = offs_m < m_block_size
638
+
639
+ # Zero out invalid elements
640
+ x_block = tl.where(m_mask[:, None], x_block, 0.0)
641
+ grad_output_block = tl.where(
642
+ m_mask[:, None], grad_output_block, 0.0
643
+ )
644
+ else:
645
+ # Manual load with bounds checking
646
+ offs_m = tl.arange(0, BLOCK_SIZE_M)
647
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
648
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
649
+
650
+ # Create masks
651
+ m_mask = offs_m < m_block_size
652
+ n_mask = offs_n < N - n_offset
653
+ k_mask = offs_k < K - k_offset
654
+
655
+ # Combined masks
656
+ mk_mask = m_mask[:, None] & k_mask[None, :]
657
+ mn_mask = m_mask[:, None] & n_mask[None, :]
658
+
659
+ # Global offsets for loading
660
+ m_global_offs = m_global_offset + offs_m
661
+
662
+ # Load x block [M_chunk, K]
663
+ x_block = tl.load(
664
+ x_desc_ptr
665
+ + m_global_offs[:, None] * K
666
+ + (k_offset + offs_k)[None, :],
667
+ mask=mk_mask,
668
+ other=0.0,
669
+ )
670
+
671
+ # Load grad_output block [M_chunk, N]
672
+ grad_output_block = tl.load(
673
+ grad_output_desc_ptr
674
+ + m_global_offs[:, None] * N
675
+ + (n_offset + offs_n)[None, :],
676
+ mask=mn_mask,
677
+ other=0.0,
678
+ )
679
+
680
+ # Compute partial contribution: grad_W += grad_Y.T @ X
681
+ # transpose grad_output for the matmul
682
+ contribution = tl.dot(
683
+ grad_output_block.to(tl.float32).T, # [N, M_chunk]
684
+ x_block.to(tl.float32), # [M_chunk, K]
685
+ )
686
+
687
+ # Accumulate
688
+ accumulator += contribution
689
+
690
+ # Store the result
691
+ if USE_TMA_STORE:
692
+ # Store using TMA
693
+ tl._experimental_descriptor_store(
694
+ workspace, # TMA store descriptor
695
+ accumulator.to(c_dtype),
696
+ [n_offset, k_offset],
697
+ )
698
+ else:
699
+ # Manual store with bounds checking
700
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
701
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
702
+
703
+ # Create masks for bounds checking
704
+ n_mask = offs_n < N - n_offset
705
+ k_mask = offs_k < K - k_offset
706
+ output_mask = n_mask[:, None] & k_mask[None, :]
707
+
708
+ # Store the result
709
+ tl.store(
710
+ grad_weight_ptr
711
+ + (n_offset + offs_n)[:, None] * K
712
+ + (k_offset + offs_k)[None, :],
713
+ accumulator.to(c_dtype),
714
+ mask=output_mask,
715
+ )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
725
+ def grouped_gemm_forward(
726
+ x: torch.Tensor,
727
+ w: torch.Tensor,
728
+ m_sizes: torch.Tensor,
729
+ tma_size: int = 128,
730
+ ) -> torch.Tensor:
731
+ """
732
+ M*G style grouped GEMM with TMA support.
733
+ (FP8 support is removed for now; it was triggered by passing x_scale and w_scale tensors.)
734
+
735
+ """
736
+ if not CudaUtils.verify_tma():
737
+ raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
738
+
739
+ G = m_sizes.shape[0]
740
+
741
+ assert x.is_contiguous()
742
+ assert w.is_contiguous()
743
+ assert m_sizes.is_contiguous()
744
+
745
+ # Total input size is now [M_total, K] where M_total is the sum of all group sizes
746
+ M_total, K = x.shape
747
+ N = w.shape[0] # N is now the same for all groups
748
+
749
+ assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
750
+
751
+ # Verify that all group sizes are multiples of ALIGN_SIZE_M
752
+ # This check is commented out because it will involve a GPU-CPU sync
753
+ # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"
754
+
755
+ # Create output tensor; the Hopper kernel writes per-group outputs of width N // G, so y is [M_total, N // G]
756
+ y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
757
+
758
+ if M_total == 0:
759
+ return y
760
+
761
+ NUM_SMS = CudaUtils.get_num_sms()
762
+ USE_TMA_LOAD = True
763
+ USE_TMA_STORE = True
764
+ USE_EPILOGUE_SUBTILING = False
765
+
766
+ # TMA descriptor helper
767
+ desc_helper = None
768
+ desc_x = x
769
+ desc_w = w
770
+ workspace = None
771
+
772
+ if USE_TMA_LOAD:
773
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
774
+ desc_helper.init_tma_descriptor("x")
775
+ desc_helper.init_tma_descriptor("w")
776
+ desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
777
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
778
+
779
+ if USE_TMA_STORE:
780
+ workspace = torch.empty(
781
+ NUM_SMS * desc_helper.tma_size,
782
+ device=x.device,
783
+ dtype=torch.uint8,
784
+ )
785
+
786
+ def grid(META):
787
+ if USE_TMA_LOAD:
788
+ nonlocal desc_helper
789
+ desc_helper.fill_2d_tma_descriptor(
790
+ "x",
791
+ x.data_ptr(),
792
+ M_total,
793
+ K,
794
+ META["BLOCK_SIZE_M"],
795
+ META["BLOCK_SIZE_K"],
796
+ x.element_size(),
797
+ )
798
+
799
+ desc_helper.fill_2d_tma_descriptor(
800
+ "w",
801
+ w.data_ptr(),
802
+ N,
803
+ K,
804
+ META["BLOCK_SIZE_N"],
805
+ META["BLOCK_SIZE_K"],
806
+ w.element_size(),
807
+ )
808
+ return (NUM_SMS,)
809
+
810
+ M_BUCKET = triton.next_power_of_2(M_total)
811
+
812
+ _kernel_mg_forward_hopper[grid](
813
+ desc_x,
814
+ desc_w,
815
+ y,
816
+ workspace,
817
+ m_sizes,
818
+ G,
819
+ M_BUCKET,
820
+ N,
821
+ K,
822
+ NUM_SMS,
823
+ TMA_SIZE=tma_size,
824
+ USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
825
+ )
826
+
827
+ return y
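To make the M*G layout concrete, here is a hedged usage sketch of the shapes this wrapper expects (an illustration, not part of the diff); it assumes a TMA-capable GPU and that grouped_gemm_forward is imported from this module.

import torch

G, K, N = 4, 4096, 8192  # N counts the stacked weight rows across all groups
m_sizes = torch.tensor([2048] * G, device="cuda", dtype=torch.int32)
M_total = int(m_sizes.sum().item())  # rows of x, summed over groups

x = torch.randn(M_total, K, device="cuda", dtype=torch.float16)
w = torch.randn(N, K, device="cuda", dtype=torch.float16)

y = grouped_gemm_forward(x, w, m_sizes)
# The Hopper kernel launched above gives each group its own N // G slice of w,
# so the output is [M_total, N // G].
assert y.shape == (M_total, N // G)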
828
+
829
+
830
+ # ======== Improved Backward =============
831
+ def grouped_gemm_backward(
832
+ grad_output: torch.Tensor,
833
+ x: torch.Tensor,
834
+ w: torch.Tensor,
835
+ m_sizes: torch.Tensor,
836
+ use_tma: bool = True,
837
+ tma_size: int = 128,
838
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
839
+ """
840
+ Unified backward pass for grouped GeMM with M*G grouping.
841
+ Uses optimized TMA-based implementations for both dx and dw when available.
842
+
843
+ Args:
844
+ grad_output: Gradient of output, shape [M_total, N]
845
+ x: Input tensor from forward pass, shape [M_total, K]
846
+ w: Weight tensor from forward pass, shape [N, K]
847
+ m_sizes: Group sizes tensor, shape [G]
848
+ use_tma: Whether to try using TMA acceleration (if available)
849
+ tma_size: Size of TMA descriptor in bytes
850
+
851
+
852
+ Returns:
853
+ Tuple of gradients with respect to x and w: (grad_x, grad_w)
854
+ """
855
+ logging.info("Starting unified grouped_gemm_backward")
856
+
857
+ # do this once, seems expensive
858
+ NUM_SMS = CudaUtils.get_num_sms()
859
+
860
+ # Basic validation
861
+ G = m_sizes.shape[0]
862
+ M_total, K_x = x.shape
863
+ M_grad, N = grad_output.shape
864
+ N_w, K_w = w.shape
865
+
866
+ # Check dimensions
867
+ if K_x != K_w:
868
+ raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
869
+ if M_total != M_grad:
870
+ raise ValueError(
871
+ f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
872
+ )
873
+
874
+ # Check total M matches sum of group sizes
875
+ sum_m_sizes = m_sizes.sum().item()
876
+ if M_total != sum_m_sizes:
877
+ raise ValueError(
878
+ f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
879
+ )
880
+
881
+ # Make sure inputs are contiguous
882
+ grad_output = grad_output.contiguous()
883
+ x = x.contiguous()
884
+ w = w.contiguous()
885
+ m_sizes = m_sizes.contiguous()
886
+
887
+ # Check TMA support
888
+ can_use_tma = use_tma and CudaUtils.verify_tma()
889
+ if use_tma and not can_use_tma:
890
+ logging.info("TMA requested but not supported on this device")
891
+ use_tma = False
892
+
893
+ # Compute grad_x using flat linear implementation
894
+ try:
895
+ logging.info(f"Computing grad_x with flat linear kernel")
896
+
897
+ # Use TMA-optimized implementation
898
+ grad_x = grouped_gemm_dx_tma(
899
+ grad_output=grad_output,
900
+ w=w,
901
+ m_sizes=m_sizes,
902
+ num_sms=NUM_SMS,
903
+ tma_size=tma_size,
904
+ )
905
+
906
+ except Exception as e:
907
+ logging.error(f"Error in grad_x computation: {e}")
908
+ raise
909
+
910
+ # Compute grad_w using flat linear style implementation
911
+ try:
912
+ logging.info(f"Computing grad_w with flat linear kernel")
913
+
914
+ grad_w = grouped_gemm_dw_tma(
915
+ x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
916
+ )
917
+ except Exception as e:
918
+ logging.error(f"Error in grad_w computation: {e}")
919
+ raise
920
+
921
+ return grad_x, grad_w
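For completeness, a hedged sketch of calling this backward entry point directly (outside autograd), assuming a TMA-capable GPU and that the function is imported from this module:

import torch

G, K, N = 4, 1024, 2048
m_sizes = torch.tensor([512, 512, 256, 768], device="cuda", dtype=torch.int32)
M_total = int(m_sizes.sum().item())

x = torch.randn(M_total, K, device="cuda", dtype=torch.float16)
w = torch.randn(N, K, device="cuda", dtype=torch.float16)
grad_out = torch.randn(M_total, N, device="cuda", dtype=torch.float16)

grad_x, grad_w = grouped_gemm_backward(grad_out, x, w, m_sizes)
# grad_x: [M_total, K] from the dx kernel; grad_w: [N, K] from the dw kernel.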
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
927
+ def grouped_gemm_dx_tma(
928
+ grad_output: torch.Tensor,
929
+ w: torch.Tensor,
930
+ m_sizes: torch.Tensor,
931
+ num_sms: int = 132,
932
+ tma_size: int = 128,
933
+ ) -> torch.Tensor:
934
+ """
935
+ Optimized backward pass wrapper for computing gradient with respect to input (dx)
936
+ using TMA patterns similar to the forward pass.
937
+
938
+ Args:
939
+ grad_output: Gradient of output, shape [M_total, N]
940
+ w: Weight tensor, shape [N, K]
941
+ m_sizes: Group sizes tensor, shape [G]
942
+ tma_size: Size of TMA descriptor
943
+ # using_fp8: Whether to use FP8 quantization
944
+ # grad_output_scale: Scale for grad_output in FP8 mode
945
+ # w_scale: Scale for w in FP8 mode
946
+
947
+ Returns:
948
+ grad_x: Gradient with respect to x, shape [M_total, K]
949
+ """
966
+ if not CudaUtils.verify_tma():
967
+ raise NotImplementedError("Optimized dx computation requires TMA support")
968
+
969
+ G = m_sizes.shape[0]
970
+
971
+ assert grad_output.is_contiguous()
972
+ assert w.is_contiguous()
973
+ assert m_sizes.is_contiguous()
974
+
975
+ M_total, N_grad = grad_output.shape
976
+ N_w, K = w.shape
977
+
978
+ # Check dimensions
979
+ assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"
980
+
981
+ # Verify that the sum of m_sizes matches M_total
982
+ sum_m_sizes = m_sizes.sum().item()
983
+ assert (
984
+ M_total == sum_m_sizes
985
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
986
+
987
+ # Create output tensor (grad_x) with shape [M_total, K]
988
+ grad_x = torch.empty(
989
+ (M_total, K), device=grad_output.device, dtype=grad_output.dtype
990
+ )
991
+
992
+ NUM_SMS = num_sms # CudaUtils.get_num_sms()
993
+ USE_TMA_LOAD = True
994
+ USE_TMA_STORE = True
995
+
996
+ # Set up TMA descriptors
997
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
998
+ desc_helper.init_tma_descriptor("grad_output")
999
+ desc_helper.init_tma_descriptor("w")
1000
+ desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
1001
+ desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
1002
+
1003
+ # Allocate workspace for TMA store
1004
+ workspace = torch.empty(
1005
+ NUM_SMS * desc_helper.tma_size,
1006
+ device=grad_output.device,
1007
+ dtype=torch.uint8,
1008
+ )
1009
+
1010
+ def grid(META):
1011
+ # Fill TMA descriptors with appropriate dimensions
1012
+ desc_helper.fill_2d_tma_descriptor(
1013
+ "grad_output",
1014
+ grad_output.data_ptr(),
1015
+ M_total,
1016
+ N_grad,
1017
+ META["BLOCK_SIZE_M"],
1018
+ META["BLOCK_SIZE_N"],
1019
+ grad_output.element_size(),
1020
+ )
1021
+
1022
+ desc_helper.fill_2d_tma_descriptor(
1023
+ "w",
1024
+ w.data_ptr(),
1025
+ N_w,
1026
+ K,
1027
+ META["BLOCK_SIZE_N"],
1028
+ META["BLOCK_SIZE_K"],
1029
+ w.element_size(),
1030
+ )
1031
+ return (NUM_SMS,)
1032
+
1033
+ M_BUCKET = triton.next_power_of_2(M_total)
1034
+
1035
+ # Launch the flat linear kernel for computing grad_x
1036
+ _kernel_mg_dx_tma[grid](
1037
+ desc_grad_output,
1038
+ desc_w,
1039
+ grad_x,
1040
+ workspace,
1041
+ m_sizes,
1042
+ G,
1043
+ M_BUCKET,
1044
+ N_grad, # N dimension is now the reduction dimension
1045
+ K,
1046
+ NUM_SMS,
1047
+ USE_TMA_LOAD,
1048
+ USE_TMA_STORE,
1049
+ TMA_SIZE=tma_size,
1050
+ )
1051
+
1052
+ return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
1058
+ def grouped_gemm_dw_tma(
1059
+ x: torch.Tensor,
1060
+ grad_output: torch.Tensor,
1061
+ m_sizes: torch.Tensor,
1062
+ num_sms: int = 132,
1063
+ tma_size: int = 128,
1064
+ ) -> torch.Tensor:
1065
+ """
1066
+ Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA.
1067
+ For the forward pass Y = X @ W.T, the backward for weights is:
1068
+ grad_W = grad_Y.T @ X
1069
+
1070
+ Args:
1071
+ x: Input tensor, shape [M_total, K]
1072
+ grad_output: Gradient of output, shape [M_total, N]
1073
+ m_sizes: Group sizes tensor, shape [G]
1074
+ tma_size: Size of TMA descriptor in bytes
1075
+
1076
+
1077
+ Returns:
1078
+ grad_w: Gradient with respect to weights, shape [N, K]
1079
+ """
1080
+ # Check TMA support
1081
+ has_tma_support = CudaUtils.verify_tma()
1082
+
1083
+ # Get group count
1084
+ G = m_sizes.shape[0]
1085
+
1086
+ # Ensure contiguous tensors
1087
+ x = x.contiguous()
1088
+ grad_output = grad_output.contiguous()
1089
+ m_sizes = m_sizes.contiguous()
1090
+
1091
+ # Get dimensions
1092
+ M_total, K_x = x.shape
1093
+ M_grad, N = grad_output.shape
1094
+
1095
+ # Check dimensions
1096
+ assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"
1097
+
1098
+ # Verify that the sum of m_sizes matches M_total
1099
+ sum_m_sizes = m_sizes.sum().item()
1100
+ assert (
1101
+ sum_m_sizes == M_total
1102
+ ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
1103
+
1104
+ # Create output tensor (grad_w) with shape [N, K]
1105
+ grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)
1106
+
1107
+ NUM_SMS = num_sms
1108
+
1109
+ # TODO - hardcoded for now...but should set TMA flags based on hardware support
1110
+ USE_TMA_LOAD = True # has_tma_support
1111
+ USE_TMA_STORE = True # has_tma_support
1112
+
1113
+ # Set up TMA descriptors or direct pointers
1114
+ if USE_TMA_LOAD or USE_TMA_STORE:
1115
+ desc_helper = TmaDescriptorHelper(tma_size=tma_size)
1116
+
1117
+ if USE_TMA_LOAD:
1118
+ desc_helper.init_tma_descriptor("x")
1119
+ desc_helper.init_tma_descriptor("grad_output")
1120
+ x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
1121
+ grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
1122
+ "grad_output"
1123
+ )
1124
+ else:
1125
+ x_desc = x
1126
+ grad_output_desc = grad_output
1127
+
1128
+ if USE_TMA_STORE:
1129
+ desc_helper.init_tma_descriptor("grad_w")
1130
+ workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
1131
+ else:
1132
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1133
+ else:
1134
+ # If not using TMA, just use the tensors directly
1135
+ x_desc = x
1136
+ grad_output_desc = grad_output
1137
+ workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
1138
+
1139
+ # M_BUCKET for grid size
1140
+ M_BUCKET = triton.next_power_of_2(M_total)
1141
+
1142
+ # Define grid for kernel launch
1143
+ def grid(META):
1144
+ if USE_TMA_LOAD or USE_TMA_STORE:
1145
+
1146
+ if USE_TMA_LOAD:
1147
+ desc_helper.fill_2d_tma_descriptor(
1148
+ "x",
1149
+ x.data_ptr(),
1150
+ M_total,
1151
+ K_x,
1152
+ META["BLOCK_SIZE_M"],
1153
+ META["BLOCK_SIZE_K"],
1154
+ x.element_size(),
1155
+ )
1156
+
1157
+ desc_helper.fill_2d_tma_descriptor(
1158
+ "grad_output",
1159
+ grad_output.data_ptr(),
1160
+ M_total,
1161
+ N,
1162
+ META["BLOCK_SIZE_M"],
1163
+ META["BLOCK_SIZE_N"],
1164
+ grad_output.element_size(),
1165
+ )
1166
+
1167
+ if USE_TMA_STORE:
1168
+ desc_helper.fill_2d_tma_descriptor(
1169
+ "grad_w",
1170
+ grad_w.data_ptr(),
1171
+ N,
1172
+ K_x,
1173
+ META["BLOCK_SIZE_N"],
1174
+ META["BLOCK_SIZE_K"],
1175
+ grad_w.element_size(),
1176
+ )
1177
+
1178
+ # Return grid size - one block per SM for balanced work distribution
1179
+ return (NUM_SMS,)
1180
+
1181
+ # Launch the optimized kernel
1182
+ _kernel_mg_dw_tma[grid](
1183
+ x_desc,
1184
+ grad_output_desc,
1185
+ grad_w,
1186
+ workspace,
1187
+ m_sizes,
1188
+ G,
1189
+ M_BUCKET,
1190
+ N,
1191
+ K_x,
1192
+ NUM_SMS,
1193
+ USE_TMA_LOAD,
1194
+ USE_TMA_STORE,
1195
+ TMA_SIZE=tma_size,
1196
+ )
1197
+
1198
+ return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
+ class GroupedGEMM_mg(torch.autograd.Function):
1207
+ """
1208
+ Autograd function for GroupedGEMM with M*G grouping.
1209
+ Supports both standard and FP8 quantized operations.
1210
+ """
1211
+
1212
+ @staticmethod
1213
+ def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
1214
+ """
1215
+ Forward pass of GroupedGEMM.
1216
+
1217
+ Args:
1218
+ x: Input tensor, shape [M_total, K]
1219
+ w: Weight tensor, shape [N, K]
1220
+ m_sizes: Tensor of shape [G] containing the size of each group
1221
+ use_tma: Whether to try using TMA acceleration (if available)
1222
+ tma_size: Size of TMA descriptor in bytes
1223
+ using_fp8: Whether to use FP8 quantization
1224
+
1225
+ Returns:
1226
+ Output tensor, shape [M_total, N]
1227
+ """
1228
+
1229
+ # Use regular forward without quantization
1230
+ output = grouped_gemm_forward(
1231
+ x=x, w=w, m_sizes=m_sizes, tma_size=tma_size
1232
+ )
1233
+
1234
+ # Save inputs and parameters for backward pass
1235
+ ctx.save_for_backward(x, w, m_sizes)
1236
+ ctx.use_tma = use_tma
1237
+ ctx.tma_size = tma_size
1238
+
1240
+
1241
+ return output
1242
+
1243
+ @staticmethod
1244
+ def backward(ctx, grad_output):
1245
+ """
1246
+ Backward pass of M*G GroupedGEMM.
1247
+
1248
+ Args:
1249
+ grad_output: Gradient of output, shape [M_total, N]
1250
+
1251
+ Returns:
1252
+ Tuple of gradients:
1253
+ - grad_x: Gradient with respect to x, shape [M_total, K]
1254
+ - grad_w: Gradient with respect to w, shape [N, K]
1255
+ - None: Gradient with respect to m_sizes (not differentiable)
1256
+ - None: Gradient with respect to use_tma (not differentiable)
1257
+ - None: Gradient with respect to tma_size (not differentiable)
1258
+
1259
+ """
1260
+ # Retrieve saved tensors and parameters
1261
+
1262
+ x, w, m_sizes = ctx.saved_tensors
1263
+
1264
+ use_tma = ctx.use_tma
1265
+ tma_size = ctx.tma_size
1266
+
1267
+ # Compute gradients using the unified implementation
1268
+ grad_x, grad_w = grouped_gemm_backward(
1269
+ grad_output=grad_output,
1270
+ x=x,
1271
+ w=w,
1272
+ m_sizes=m_sizes,
1273
+ use_tma=use_tma,
1274
+ tma_size=tma_size,
1275
+ )
1276
+
1277
+ # Return gradients for all inputs (None for non-differentiable parameters)
1278
+ return grad_x, grad_w, None, None, None
1279
+
1280
+
1281
+ def mg_grouped_gemm(
1282
+ x: torch.Tensor,
1283
+ w: torch.Tensor,
1284
+ m_sizes: torch.Tensor,
1285
+ use_tma: bool = True,
1286
+ tma_size: int = 128,
1287
+ using_fp8: bool = False,
1288
+ ) -> torch.Tensor:
1289
+ """
1290
+ Unified differentiable grouped GEMM operation for M*G grouped GEMM.
1291
+ Supports both standard precision and FP8 quantized operations.
1292
+
1293
+ Args:
1294
+ x: Input tensor, shape [M_total, K]
1295
+ w: Weight tensor, shape [N, K]
1296
+ m_sizes: Tensor of shape [G] containing the size of each group
1297
+ use_tma: Whether to try using TMA acceleration (if available)
1298
+ tma_size: Size of TMA descriptor in bytes
1299
+ using_fp8: Whether to use FP8 quantization (accepted but currently unused)
1300
+
1301
+ Returns:
1302
+ Output tensor, shape [M_total, N]
1303
+ """
1304
+ return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
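A minimal way to exercise the differentiable wrapper end to end could look like the sketch below (forward only). It is a usage sketch under assumptions: a TMA-capable GPU, and that this file is importable as mg_grouped_gemm; neither is part of the diff.

import torch
from mg_grouped_gemm import mg_grouped_gemm  # assumed import path

G, K, N = 4, 2048, 4096
m_sizes = torch.tensor([1024, 1024, 512, 1536], device="cuda", dtype=torch.int32)
M_total = int(m_sizes.sum().item())

x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

y = mg_grouped_gemm(x, w, m_sizes)  # routed through GroupedGEMM_mg.apply
print(y.shape)  # [M_total, N // G] with the Hopper forward kernel launched above

For training, the same call inside an autograd graph dispatches to grouped_gemm_backward for dx and dw; note that those kernels consume a gradient of width N, the full stacked weight rows.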
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
16
+ )
17
+
18
+
19
+ def compute_reference_forward(x, w, m_sizes):
20
+ """
21
+ Compute reference forward pass using PyTorch operations.
22
+
23
+ Args:
24
+ x (torch.Tensor): Input tensor of shape (M, K)
25
+ w (torch.Tensor): Weight tensor of shape (N, K)
26
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
27
+
28
+ Returns:
29
+ torch.Tensor: Reference output tensor of shape (M, N)
30
+ """
31
+ result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
32
+
33
+ m_start = 0
34
+ for g in range(len(m_sizes)):
35
+ m_size = m_sizes[g].item()
36
+ if m_size > 0:
37
+ m_end = m_start + m_size
38
+
39
+ # Extract group input
40
+ x_g = x[m_start:m_end]
41
+
42
+ # Compute group output: y_g = x_g @ w.T
43
+ y_g = torch.matmul(x_g, w.T)
44
+
45
+ # Store result
46
+ result[m_start:m_end] = y_g
47
+
48
+ # Update start index
49
+ m_start = m_end
50
+
51
+ return result
52
+
53
+
54
+ def compute_reference_backward(x, w, m_sizes, grad_output):
55
+ """
56
+ Compute reference backward pass using PyTorch autograd.
57
+
58
+ Args:
59
+ x (torch.Tensor): Input tensor of shape (M, K)
60
+ w (torch.Tensor): Weight tensor of shape (N, K)
61
+ m_sizes (torch.Tensor): Group sizes tensor of shape (G)
62
+ grad_output (torch.Tensor): Gradient tensor of shape (M, N)
63
+
64
+ Returns:
65
+ tuple: (grad_x, grad_w) gradient tensors
66
+ """
67
+ # Create autograd-enabled copies
68
+ x_autograd = x.detach().clone().requires_grad_(True)
69
+ w_autograd = w.detach().clone().requires_grad_(True)
70
+
71
+ # Compute forward pass
72
+ output = compute_reference_forward(x_autograd, w_autograd, m_sizes)
73
+
74
+ # Backpropagate
75
+ output.backward(grad_output)
76
+
77
+ return x_autograd.grad, w_autograd.grad
78
+
79
+
80
+ def analyze_tensor_differences(actual, expected, name):
81
+ """
82
+ Analyze differences between actual and expected tensors.
83
+
84
+ Args:
85
+ actual (torch.Tensor): Actual tensor
86
+ expected (torch.Tensor): Expected tensor
87
+ name (str): Name of the tensor for logging
88
+
89
+ Returns:
90
+ bool: True if tensors are close enough
91
+ """
92
+ rtol = 0.5 # Relative tolerance for float16
93
+ atol = 0.5 # Absolute tolerance for float16
94
+
95
+ # Analyze differences
96
+ diff = (actual - expected).abs()
97
+ max_idx = diff.argmax().item()
98
+ idx = np.unravel_index(max_idx, actual.shape)
99
+ max_diff = diff.max().item()
100
+
101
+ logging.info(f"Largest {name} difference: {max_diff} at {idx}")
102
+ logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")
103
+
104
+ is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)
105
+
106
+ if is_close:
107
+ logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
108
+ else:
109
+ logging.error(f"✗ FAILURE: {name} mismatch detected")
110
+
111
+ # Count zeros
112
+ zeros_actual = (actual == 0).sum().item()
113
+ zeros_expected = (expected == 0).sum().item()
114
+ logging.info(
115
+ f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
116
+ )
117
+ logging.info(
118
+ f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
119
+ )
120
+
121
+ # Check for NaNs
122
+ nan_actual = torch.isnan(actual).sum().item()
123
+ if nan_actual > 0:
124
+ logging.error(f"NaN values detected in {name}: {nan_actual}")
125
+
126
+ return is_close
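These helpers are enough to sanity-check a candidate implementation against PyTorch. The following is a hedged sketch, assuming a CUDA GPU and that the file is importable as reference_utils, with a plain matmul standing in for the kernel under test.

import torch
from reference_utils import (
    analyze_tensor_differences,
    compute_reference_backward,
    compute_reference_forward,
)

M, K, N, G = 1024, 256, 512, 4
m_sizes = torch.tensor([256] * G, device="cuda", dtype=torch.int32)
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
w = torch.randn(N, K, device="cuda", dtype=torch.float16)

expected = compute_reference_forward(x, w, m_sizes)
actual = torch.matmul(x, w.T)  # stand-in for a kernel under test; matches the reference here
assert analyze_tensor_differences(actual, expected, "forward output")

grad_out = torch.randn_like(expected)
ref_dx, ref_dw = compute_reference_backward(x, w, m_sizes, grad_out)  # autograd ground truth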
torchtitan/experiments/llama4/infra/parallelize_llama.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed.device_mesh import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+
15
+ from torchtitan.models.llama3.parallelize_llama import (
16
+ apply_ac,
17
+ apply_compile,
18
+ apply_ddp,
19
+ apply_fsdp,
20
+ apply_tp,
21
+ )
22
+ from torchtitan.tools.logging import logger
23
+
24
+
25
+ def parallelize_llama(
26
+ model: nn.Module,
27
+ world_mesh: DeviceMesh,
28
+ parallel_dims: ParallelDims,
29
+ job_config: JobConfig,
30
+ ):
31
+ """
32
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
33
+ parallelism to the model.
34
+
35
+ NOTE: The passed-in model should preferably be on the meta device. Otherwise,
36
+ the model must fit on GPU or CPU memory.
37
+ """
38
+
39
+ if parallel_dims.tp_enabled:
40
+ if (
41
+ job_config.parallelism.enable_async_tensor_parallel
42
+ and not job_config.training.compile
43
+ ):
44
+ raise RuntimeError("Async TP requires --training.compile")
45
+
46
+ enable_float8_linear = "float8" in job_config.model.converters
47
+ float8_is_rowwise = job_config.float8.recipe_name in (
48
+ "rowwise",
49
+ "rowwise_with_gw_hp",
50
+ )
51
+
52
+ # For now, float8 all-gather with TP is only supported for tensorwise
53
+ # float8 scaling recipes. For rowwise recipes, we use regular TP and
54
+ # all-gather happens in high precision.
55
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
56
+
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
62
+ enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
63
+ )
64
+
65
+ apply_moe_tp(model, world_mesh["tp"])
66
+
67
+ if job_config.activation_checkpoint.mode != "none":
68
+ if (
69
+ job_config.activation_checkpoint.mode == "selective"
70
+ and job_config.model.use_flex_attn
71
+ ):
72
+ raise ValueError(
73
+ "FlexAttention is not compatible with selective AC yet. "
74
+ "See https://github.com/pytorch/pytorch/issues/147879"
75
+ )
76
+ apply_ac(model, job_config.activation_checkpoint)
77
+
78
+ # turn on per-TransformerBlock compile after AC wrapping and before FSDP
79
+ if job_config.training.compile:
80
+ apply_compile(model)
81
+
82
+ # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
83
+ torch._dynamo.config.capture_scalar_outputs = True
84
+
85
+ if (
86
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
87
+ ): # apply FSDP or HSDP, potentially with Context Parallel
88
+ if parallel_dims.dp_replicate_enabled:
89
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
90
+ else:
91
+ dp_mesh_dim_names = ("dp_shard_cp",)
92
+
93
+ apply_fsdp(
94
+ model,
95
+ world_mesh[tuple(dp_mesh_dim_names)],
96
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
97
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
98
+ pp_enabled=parallel_dims.pp_enabled,
99
+ cpu_offload=job_config.training.enable_cpu_offload,
100
+ reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
101
+ )
102
+
103
+ if parallel_dims.dp_replicate_enabled:
104
+ logger.info("Applied HSDP to the model")
105
+ else:
106
+ logger.info("Applied FSDP to the model")
107
+
108
+ if parallel_dims.cp_enabled:
109
+ logger.info("Applied Context Parallel to the model")
110
+
111
+ if job_config.training.enable_cpu_offload:
112
+ logger.info("Applied CPU Offloading to the model")
113
+ elif parallel_dims.dp_replicate_enabled:
114
+ if world_mesh.ndim > 1:
115
+ raise RuntimeError("DDP has not supported > 1D parallelism")
116
+ apply_ddp(
117
+ model,
118
+ world_mesh,
119
+ enable_compile=job_config.training.compile,
120
+ enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
121
+ )
122
+
123
+ return model
124
+
125
+
126
+ def apply_moe_tp(
127
+ model: nn.Module,
128
+ tp_mesh: DeviceMesh,
129
+ ):
130
+ from torch.distributed.tensor import Partial, Replicate, Shard
131
+ from torch.distributed.tensor.parallel import (
132
+ parallelize_module,
133
+ PrepareModuleInputOutput,
134
+ )
135
+
136
+ from .expert_parallel import NoParallel, TensorParallel
137
+
138
+ for _, transformer_block in model.layers.items():
139
+ moe_layer_plan = {
140
+ # input / output sharding on the seqlen dim
141
+ # all-gather for input, reduce-scatter for output
142
+ "moe": PrepareModuleInputOutput(
143
+ input_layouts=(Shard(1),),
144
+ desired_input_layouts=(Replicate(),),
145
+ use_local_input=True,
146
+ output_layouts=(Partial(),),
147
+ desired_output_layouts=(Shard(1),),
148
+ ),
149
+ # replicate computation for the router
150
+ "moe.router.gate": NoParallel(),
151
+ # input Replicate, output Partial
152
+ "moe.experts": TensorParallel(),
153
+ "moe.shared_expert": TensorParallel(),
154
+ }
155
+ parallelize_module(
156
+ module=transformer_block,
157
+ device_mesh=tp_mesh,
158
+ parallelize_plan=moe_layer_plan,
159
+ )
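The float8/TP gating at the top of `parallelize_llama` can be restated as a small standalone predicate. The snippet below (plain Python, no meshes or GPUs) only mirrors that boolean logic; the config values are invented for illustration.

```python
# Illustrative restatement of the float8 tensorwise-TP gating above; values are invented.
def float8_tensorwise_tp_enabled(converters: list[str], recipe_name: str | None) -> bool:
    enable_float8_linear = "float8" in converters
    float8_is_rowwise = recipe_name in ("rowwise", "rowwise_with_gw_hp")
    # Rowwise recipes fall back to regular TP; all-gather stays in high precision.
    return enable_float8_linear and not float8_is_rowwise

assert float8_tensorwise_tp_enabled(["float8"], None) is True
assert float8_tensorwise_tp_enabled(["float8"], "rowwise") is False
assert float8_tensorwise_tp_enabled([], None) is False
```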
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
torchtitan/experiments/llama4/model/model.py ADDED
@@ -0,0 +1,466 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from torchtitan.models.attention import build_attention, init_attention_mask
13
+ from torchtitan.models.norms import build_norm
14
+ from torchtitan.protocols.train_spec import ModelProtocol
15
+
16
+ from .args import TransformerModelArgs
17
+ from .moe import MoE
18
+
19
+
20
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
21
+ """
22
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
23
+
24
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
25
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
26
+ The returned tensor contains complex values in complex64 data type.
27
+
28
+ Args:
29
+ dim (int): Dimension of the frequency tensor.
30
+ end (int): End index for precomputing frequencies.
31
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
32
+
33
+ Returns:
34
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
35
+ """
36
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
37
+ t = torch.arange(end, device=freqs.device)
38
+ freqs = torch.outer(t, freqs).float()
39
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
40
+ return freqs_cis
41
+
42
+
43
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Reshape frequency tensor for broadcasting it with another tensor.
46
+
47
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
48
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
49
+
50
+ The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
51
+ and the first seqlen elements will be sliced, but dim must match x.
52
+
53
+ Args:
54
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
55
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
56
+
57
+ Returns:
58
+ torch.Tensor: Reshaped frequency tensor.
59
+ """
60
+ ndim = x.ndim
61
+ assert ndim > 1
62
+ seqlen = x.shape[1]
63
+ freqs_cis = freqs_cis[0:seqlen]
64
+ assert freqs_cis.shape == (seqlen, x.shape[-1])
65
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
66
+ return freqs_cis.view(*shape)
67
+
68
+
69
+ def apply_rotary_emb(
70
+ xq: torch.Tensor,
71
+ xk: torch.Tensor,
72
+ freqs_cis: torch.Tensor,
73
+ ) -> tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Apply rotary embeddings to input tensors using the given frequency tensor.
76
+
77
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
78
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
79
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
80
+ returned as real tensors.
81
+
82
+ Args:
83
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
84
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
85
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
86
+
87
+ Returns:
88
+ tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
89
+ """
90
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
91
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
92
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
93
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
94
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
95
+ return xq_out.type_as(xq), xk_out.type_as(xk)
96
+
97
+
98
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
99
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
100
+ bs, slen, n_kv_heads, head_dim = x.shape
101
+ if n_rep == 1:
102
+ return x
103
+ return (
104
+ torch.unsqueeze(x, dim=3)
105
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
106
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
107
+ )
108
+
109
+
110
+ class Attention(nn.Module):
111
+ """
112
+ Multi-head attention module.
113
+
114
+ Args:
115
+ model_args (TransformerModelArgs): Model configuration arguments.
116
+
117
+ Attributes:
118
+ n_kv_heads (int): Number of key and value heads.
119
+ n_heads (int): Number of query heads.
120
+ n_rep (int): Number of repetitions for local heads.
121
+ head_dim (int): Dimension size of each attention head.
122
+ wq (Linear): Linear transformation for queries.
123
+ wk (Linear): Linear transformation for keys.
124
+ wv (Linear): Linear transformation for values.
125
+ wo (Linear): Linear transformation for output.
126
+
127
+ """
128
+
129
+ def __init__(self, model_args: TransformerModelArgs):
130
+ super().__init__()
131
+ self.n_heads = model_args.n_heads
132
+ self.n_kv_heads = (
133
+ model_args.n_heads
134
+ if model_args.n_kv_heads is None
135
+ else model_args.n_kv_heads
136
+ )
137
+ self.n_rep = self.n_heads // self.n_kv_heads
138
+ self.head_dim = model_args.dim // model_args.n_heads
139
+
140
+ self.wq = nn.Linear(
141
+ model_args.dim, model_args.n_heads * self.head_dim, bias=False
142
+ )
143
+ self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
144
+ self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
145
+ self.wo = nn.Linear(
146
+ model_args.n_heads * self.head_dim, model_args.dim, bias=False
147
+ )
148
+ self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
149
+
150
+ def init_weights(self, init_std: float):
151
+ for linear in (self.wq, self.wk, self.wv):
152
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
153
+ nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ freqs_cis: torch.Tensor,
159
+ ):
160
+ """
161
+ Forward pass of the attention module.
162
+
163
+ Args:
164
+ x (torch.Tensor): Input tensor.
165
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
166
+
167
+ Returns:
168
+ torch.Tensor: Output tensor after attention.
169
+
170
+ """
171
+
172
+ bs, seqlen, _ = x.shape
173
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
174
+
175
+ # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
176
+ # local heads from sizes of xq, xk, and xv as TP may have sharded them
177
+ # after the above linear ops.
178
+ xq = xq.view(bs, seqlen, -1, self.head_dim)
179
+ xk = xk.view(bs, seqlen, -1, self.head_dim)
180
+ xv = xv.view(bs, seqlen, -1, self.head_dim)
181
+
182
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
183
+
184
+ # repeat k/v heads if n_kv_heads < n_heads
185
+ keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
186
+ values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
187
+
188
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
189
+ xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
190
+ xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
191
+
192
+ output = self.sdpa(xq, xk, xv)
193
+
194
+ output = output.transpose(
195
+ 1, 2
196
+ ).contiguous() # (bs, seqlen, n_local_heads, head_dim)
197
+ output = output.view(bs, seqlen, -1)
198
+ return self.wo(output)
199
+
200
+
201
+ class FeedForward(nn.Module):
202
+ """
203
+ FeedForward module
204
+
205
+ Args:
206
+ dim (int): Input dimension.
207
+ hidden_dim (int): Hidden dimension of the feedforward layer.
208
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
209
+ ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
210
+
211
+ Attributes:
212
+ w1 (Linear): Linear transformation for the first layer.
213
+ w2 (Linear): Linear transformation for the second layer.
214
+ w3 (Linear): Linear transformation for the third layer.
215
+
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ dim: int,
221
+ hidden_dim: int,
222
+ multiple_of: int,
223
+ ffn_dim_multiplier: float | None,
224
+ ):
225
+ super().__init__()
226
+ hidden_dim = int(2 * hidden_dim / 3)
227
+ # custom dim factor multiplier
228
+ if ffn_dim_multiplier is not None:
229
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
230
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
231
+
232
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
233
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
234
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
235
+
236
+ def forward(self, x):
237
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
238
+
239
+ def init_weights(self, init_std: float):
240
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
241
+ for linear in (self.w2, self.w3):
242
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
243
+
244
+
245
+ class TransformerBlock(nn.Module):
246
+ """
247
+ TransformerBlock Module
248
+
249
+ Args:
250
+ layer_id (int): Identifier for the layer.
251
+ model_args (TransformerModelArgs): Model configuration arguments.
252
+
253
+ Attributes:
254
+ n_heads (int): Number of attention heads.
255
+ dim (int): Dimension size of the model.
256
+ head_dim (int): Dimension size of each attention head.
257
+ attention (Attention): Attention module.
258
+ feed_forward (FeedForward): FeedForward module.
259
+ layer_id (int): Identifier for the layer.
260
+ attention_norm (RMSNorm): Layer normalization for attention output.
261
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
262
+
263
+ """
264
+
265
+ def __init__(self, layer_id: int, model_args: TransformerModelArgs):
266
+ super().__init__()
267
+ self.n_heads = model_args.n_heads
268
+ self.dim = model_args.dim
269
+ self.attention = Attention(model_args)
270
+
271
+ # use MoE layer for every interleave_moe_layer_step FFN layers
272
+ self.moe_enabled = (
273
+ model_args.moe_enabled
274
+ and (layer_id + 1) % model_args.interleave_moe_layer_step == 0
275
+ )
276
+ if self.moe_enabled:
277
+ self.moe = MoE(model_args)
278
+ else:
279
+ self.feed_forward = FeedForward(
280
+ dim=model_args.dim,
281
+ hidden_dim=4 * model_args.dim,
282
+ multiple_of=model_args.multiple_of,
283
+ ffn_dim_multiplier=model_args.ffn_dim_multiplier,
284
+ )
285
+
286
+ self.layer_id = layer_id
287
+ self.num_layers = model_args.n_layers
288
+
289
+ self.attention_norm = build_norm(
290
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
291
+ )
292
+ self.ffn_norm = build_norm(
293
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
294
+ )
295
+
296
+ if model_args.depth_init:
297
+ self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
298
+ else:
299
+ self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
300
+
301
+ def forward(
302
+ self,
303
+ x: torch.Tensor,
304
+ freqs_cis: torch.Tensor,
305
+ ):
306
+ """
307
+ Perform a forward pass through the TransformerBlock.
308
+
309
+ Args:
310
+ x (torch.Tensor): Input tensor.
311
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
312
+
313
+ Returns:
314
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
315
+
316
+ """
317
+ h = x + self.attention(self.attention_norm(x), freqs_cis)
318
+ if self.moe_enabled:
319
+ out = h + self.moe(self.ffn_norm(h))
320
+ else:
321
+ out = h + self.feed_forward(self.ffn_norm(h))
322
+ return out
323
+
324
+ def init_weights(self):
325
+ for norm in (self.attention_norm, self.ffn_norm):
326
+ norm.reset_parameters()
327
+ self.attention.init_weights(self.weight_init_std)
328
+ if self.moe_enabled:
329
+ self.moe.init_weights(self.weight_init_std)
330
+ else:
331
+ self.feed_forward.init_weights(self.weight_init_std)
332
+
333
+
334
+ class Transformer(nn.Module, ModelProtocol):
335
+ """
336
+ Transformer Module
337
+
338
+ Args:
339
+ model_args (TransformerModelArgs): Model configuration arguments.
340
+
341
+ Attributes:
342
+ model_args (TransformerModelArgs): Model configuration arguments.
343
+ vocab_size (int): Vocabulary size.
344
+ n_layers (int): Number of layers in the model.
345
+ tok_embeddings (ParallelEmbedding): Token embeddings.
346
+ layers (torch.nn.ModuleList): List of Transformer blocks.
347
+ norm (RMSNorm): Layer normalization for the model output.
348
+ output (ColumnParallelLinear): Linear layer for final output.
349
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
350
+
351
+ """
352
+
353
+ def __init__(self, model_args: TransformerModelArgs):
354
+ super().__init__()
355
+ self.model_args = model_args
356
+ self.vocab_size = model_args.vocab_size
357
+ self.n_layers = model_args.n_layers
358
+ self.eos_id = model_args.eos_id
359
+
360
+ self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
361
+
362
+ # TODO persistent should be set to false, since this buffer can be recomputed.
363
+ # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
364
+ # compile or pipeline-tracer will not correctly handle non-persistent buffers,
365
+ # so we need to fix that. (2) if we initialize pipeline-parallel models from
366
+ # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
367
+ # initialized by the checkpoint, or we need to add a separate initializer for
368
+ # just the non-persistent buffers that is called after loading checkpoints.
369
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
370
+
371
+ self.layers = torch.nn.ModuleDict()
372
+ for layer_id in range(model_args.n_layers):
373
+ self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
374
+
375
+ self.norm = build_norm(
376
+ model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
377
+ )
378
+
379
+ self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
380
+ self.init_weights()
381
+
382
+ def init_weights(
383
+ self,
384
+ buffer_device: torch.device | None = None,
385
+ ):
386
+ """
387
+ [Note: On ``init_weights`` vs. ``reset_parameters``]
388
+ Modules may define ``reset_parameters`` to initialize parameter values.
389
+ ``reset_parameters`` is meant to only initialize directly owned
390
+ parameters/buffers, not those of their child modules, and it can be
391
+ used to give the initial values for these tensors.
392
+ Separately, users may want custom initialization for their modules,
393
+ different from that in ``reset_parameters``. For this, we define
394
+ ``init_weights``. We only call it in the constructor of this
395
+ ``Transformer`` root module to avoid reinitializing tensors.
396
+ """
397
+ buffer_device = buffer_device or self.freqs_cis.device
398
+ with torch.device(buffer_device):
399
+ self.freqs_cis = self._precompute_freqs_cis()
400
+ if self.tok_embeddings is not None:
401
+ nn.init.normal_(self.tok_embeddings.weight)
402
+ for layer in self.layers.values():
403
+ if layer is not None:
404
+ layer.init_weights()
405
+ if self.norm is not None:
406
+ self.norm.reset_parameters()
407
+ final_out_std = self.model_args.dim**-0.5
408
+ cutoff_factor = 3
409
+ if self.output is not None:
410
+ nn.init.trunc_normal_(
411
+ self.output.weight,
412
+ mean=0.0,
413
+ std=final_out_std,
414
+ a=-cutoff_factor * final_out_std,
415
+ b=cutoff_factor * final_out_std,
416
+ )
417
+
418
+ def _precompute_freqs_cis(self) -> torch.Tensor:
419
+ return precompute_freqs_cis(
420
+ self.model_args.dim // self.model_args.n_heads,
421
+ # Need to compute until at least the max token limit for generation
422
+ # TODO: explain in docs/composability.md why we removed the 2x
423
+ # relaxing in our CP enablement PR
424
+ self.model_args.max_seq_len,
425
+ self.model_args.rope_theta,
426
+ )
427
+
428
+ def forward(self, tokens: torch.Tensor):
429
+ """
430
+ Perform a forward pass through the Transformer model.
431
+
432
+ Args:
433
+ tokens (torch.Tensor): Input token indices.
434
+
435
+ Returns:
436
+ torch.Tensor: Output logits after applying the Transformer model.
437
+
438
+ """
439
+ # TODO: We will need to change the forward() signature to allow tokens to
440
+ # be always passed in.
441
+ if self.model_args.use_flex_attn:
442
+ init_attention_mask(tokens, eos_id=self.eos_id)
443
+
444
+ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
445
+ h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
446
+
447
+ for layer in self.layers.values():
448
+ h = layer(h, self.freqs_cis)
449
+
450
+ h = self.norm(h) if self.norm else h
451
+ output = self.output(h) if self.output else h
452
+ return output
453
+
454
+ @classmethod
455
+ def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":
456
+ """
457
+ Initialize a Transformer model from a TransformerModelArgs object.
458
+
459
+ Args:
460
+ model_args (TransformerModelArgs): Model configuration arguments.
461
+
462
+ Returns:
463
+ Transformer: Transformer model.
464
+
465
+ """
466
+ return cls(model_args)
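As a quick shape check for the RoPE helpers at the top of this file (the import path is assumed from the file location above; the sizes are arbitrary):

```python
import torch

# Assumed import path based on the file location above.
from torchtitan.experiments.llama4.model.model import apply_rotary_emb, precompute_freqs_cis

bs, seqlen, n_heads, head_dim = 2, 16, 4, 8
freqs_cis = precompute_freqs_cis(head_dim, seqlen)  # (seqlen, head_dim // 2), complex64
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
assert xq_rot.shape == (bs, seqlen, n_heads, head_dim)
assert xk_rot.shape == (bs, seqlen, n_heads, head_dim)
```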
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import Any, Optional
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
16
+ from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
17
+ from torchtitan.components.checkpoint import MODEL
18
+ from torchtitan.config_manager import JobConfig
19
+ from torchtitan.tools.logging import init_logger, logger
20
+ from torchtitan.train import Trainer
21
+
22
+ # Sharding dims for MP checkpoints
23
+
24
+ column_parallel = [
25
+ "tok_embeddings",
26
+ "wq",
27
+ "wk",
28
+ "wv",
29
+ "wqkv",
30
+ "w_in_shared_FD",
31
+ "w_out_eF_D",
32
+ "w_swiglu_FD",
33
+ "output",
34
+ "_linear",
35
+ "c_fc",
36
+ "vision_projection",
37
+ ]
38
+
39
+ row_parallel = [
40
+ "wo",
41
+ "w_out_shared_DF",
42
+ "w_in_eD_F",
43
+ "moe_w_swiglu_eD_F",
44
+ "c_proj",
45
+ ]
46
+
47
+
48
+ def convert_to_titan_fqns(fqn: str) -> list[str]:
49
+ # From the stored checkpoint keys to TorchTitan keys.
50
+ if "wqkv" in fqn and "layer_norm_weight" not in fqn:
51
+ ret = []
52
+ for k in ("wq", "wk", "wv"):
53
+ ret.append(fqn.replace("wqkv", k))
54
+ return ret
55
+ return [fqn]
56
+
57
+
58
+ def get_shard_dim(fqn: str) -> Optional[int]:
59
+ if "bias" in fqn:
60
+ # Some bias params are still sharded
61
+ if "resblocks" in fqn:
62
+ for k in ("wq", "wk", "wv", "c_fc"):
63
+ if k in fqn:
64
+ return 0
65
+ return None
66
+ elif any([x in fqn for x in column_parallel]):
67
+ return 0
68
+ elif any([x in fqn for x in row_parallel]):
69
+ return 1
70
+ else:
71
+ return None
72
+
73
+
74
+ def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
75
+ qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards]
76
+ q = torch.cat([qkv[0] for qkv in qkvs], dim=0)
77
+ k = torch.cat([qkv[1] for qkv in qkvs], dim=0)
78
+ v = torch.cat([qkv[2] for qkv in qkvs], dim=0)
79
+ return q, k, v
80
+
81
+
82
+ @dataclass
83
+ class _Assignment:
84
+ loader_id: int
85
+ filename: str
86
+ fqns: tuple[str, ...]
87
+ shapes: tuple[torch.Size, ...]
88
+ dtypes: tuple[torch.dtype, ...]
89
+
90
+
91
+ @dataclass
92
+ class _AssignmentRound:
93
+ loader_assignments: dict[int, _Assignment] # List of assignments for each loader
94
+
95
+
96
+ class CheckpointConverter:
97
+ TOTAL_SHARDS = 8
98
+
99
+ def __init__(
100
+ self,
101
+ process_group: dist.ProcessGroup,
102
+ path: str,
103
+ loader_every_n_ranks: int = 8,
104
+ ) -> None:
105
+ self.path = path
106
+ self.pg = process_group
107
+ self.my_rank = dist.get_rank(self.pg)
108
+ self.loader_every_n_ranks = loader_every_n_ranks
109
+ self.loader_id = self.my_rank // loader_every_n_ranks
110
+ self.should_load = (
111
+ self.my_rank % loader_every_n_ranks == 0
112
+ and self.loader_id < CheckpointConverter.TOTAL_SHARDS
113
+ )
114
+ self.total_loader = CheckpointConverter.TOTAL_SHARDS
115
+ self.titan_fqn_to_stored_fqn: dict[str, str] = {}
116
+ self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {}
117
+ self.total_send_bytes = 0
118
+ self.total_recv_bytes = 0
119
+
120
+ def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
121
+ begin = time.time()
122
+ self._load_metadata()
123
+ self._create_fqn_mappings(state_dict)
124
+ rounds = self._get_load_assignments(state_dict)
125
+
126
+ for assignments in rounds:
127
+ loader_assignments = assignments.loader_assignments
128
+ loaded_state_dict = None
129
+ # Let each loader to load its own data and move to its GPU.
130
+ for i in range(self.total_loader):
131
+ # This loader doesn't have any loading assignment for this round.
132
+ if i not in loader_assignments:
133
+ continue
134
+ # This rank is not the loader
135
+ if i != self.loader_id or not self.should_load:
136
+ continue
137
+ loaded_state_dict = self._load_round(loader_assignments[i])
138
+
139
+ results = []
140
+ for i in range(self.total_loader):
141
+ if i not in loader_assignments:
142
+ continue
143
+
144
+ if i == self.loader_id and self.should_load:
145
+ # This rank is the loader. It needs to send the loaded data to
146
+ # the other ranks.
147
+ assert loaded_state_dict is not None
148
+ results.append(
149
+ self._reshard_send(loader_assignments[i], loaded_state_dict)
150
+ )
151
+ else:
152
+ results.append(
153
+ self._reshard_receive(loader_assignments[i], state_dict)
154
+ )
155
+
156
+ self._reshard(results, state_dict)
157
+
158
+ torch.cuda.synchronize()
159
+ logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.")
160
+ logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB")
161
+ logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB")
162
+ return state_dict
163
+
164
+ def _get_file_path(self, loader_id: int) -> str:
165
+ return os.path.join(self.path, f"consolidated.0{loader_id}.pth")
166
+
167
+ def _load_metadata(self) -> None:
168
+ if not self.should_load:
169
+ self.read_dict = {}
170
+ return
171
+ self.read_dict = torch.load(
172
+ self._get_file_path(self.loader_id),
173
+ mmap=True,
174
+ weights_only=False,
175
+ )
176
+
177
+ def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None:
178
+ if not self.read_dict:
179
+ return
180
+
181
+ # Create the mapping from the stored checkpoint keys to TorchTitan keys.
182
+ for fqn in list(self.read_dict.keys()):
183
+ titan_fqns = convert_to_titan_fqns(fqn)
184
+ # We don't know how to process _extra_state
185
+ if "_extra_state" in fqn:
186
+ self.read_dict.pop(fqn)
187
+ continue
188
+
189
+ if titan_fqns[0] not in state_dict:
190
+ for titan_fqn in titan_fqns:
191
+ assert titan_fqn not in state_dict
192
+ self.read_dict.pop(fqn)
193
+ continue
194
+ self.stored_fqn_to_titan_fqn[fqn] = titan_fqns
195
+ for titan_fqn in titan_fqns:
196
+ self.titan_fqn_to_stored_fqn[titan_fqn] = fqn
197
+
198
+ assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), (
199
+ set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()),
200
+ set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()),
201
+ )
202
+
203
+ def _get_load_assignments(
204
+ self, state_dict: dict[str, torch.Tensor]
205
+ ) -> list[_AssignmentRound]:
206
+ if self.my_rank == 0:
207
+ rounds: list[_AssignmentRound] = []
208
+ size = 0
209
+ fqns = []
210
+ shapes = []
211
+ dtypes = []
212
+
213
+ # All loader must load all the FQNs because the checkpoint is purely TP sharded.
214
+ all_keys = list(self.read_dict.keys())
215
+ for fqn in all_keys:
216
+ fqns.append(fqn)
217
+ shapes.append(self.read_dict[fqn].shape)
218
+ dtypes.append(self.read_dict[fqn].dtype)
219
+ size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size()
220
+ if size < 1e9 and fqn != all_keys[-1]:
221
+ continue
222
+
223
+ logger.info(f"Adding {fqns} to round {len(rounds)}")
224
+ round_assignment = _AssignmentRound(loader_assignments={})
225
+ for loader_id in range(self.total_loader):
226
+ path = self._get_file_path(loader_id)
227
+ round_assignment.loader_assignments[loader_id] = _Assignment(
228
+ filename=path,
229
+ fqns=tuple(fqns),
230
+ shapes=tuple(shapes),
231
+ dtypes=tuple(dtypes),
232
+ loader_id=loader_id,
233
+ )
234
+ rounds.append(round_assignment)
235
+ size = 0
236
+ fqns.clear()
237
+ shapes.clear()
238
+ dtypes.clear()
239
+
240
+ object_list: list[Any] = [
241
+ rounds,
242
+ self.titan_fqn_to_stored_fqn,
243
+ self.stored_fqn_to_titan_fqn,
244
+ ]
245
+ else:
246
+ object_list = [None, None, None]
247
+
248
+ dist.broadcast_object_list(object_list, src=0, group=self.pg)
249
+ rounds = object_list[0]
250
+ self.titan_fqn_to_stored_fqn = object_list[1]
251
+ self.stored_fqn_to_titan_fqn = object_list[2]
252
+ return rounds
253
+
254
+ def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]:
255
+ ret = {}
256
+ assert self.read_dict
257
+ for fqn in assignment.fqns:
258
+ ret[fqn] = self.read_dict[fqn].to(device="cuda")
259
+ return ret
260
+
261
+ def _reshard_send(
262
+ self,
263
+ assignment: _Assignment,
264
+ loaded_state_dict: dict[str, torch.Tensor],
265
+ ) -> dict[str, torch.Tensor]:
266
+ flatten_tensors = [t.flatten() for t in loaded_state_dict.values()]
267
+ flatten_tensor = torch.concat(flatten_tensors)
268
+ assert self.loader_id == assignment.loader_id
269
+ rank = self.loader_id * self.loader_every_n_ranks
270
+ assert rank == self.my_rank
271
+ logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}")
272
+ logger.info(f"Sending {assignment.fqns}")
273
+ dist.broadcast(flatten_tensor, src=rank, group=self.pg)
274
+ self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
275
+ return loaded_state_dict
276
+
277
+ def _reshard_receive(
278
+ self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
279
+ ) -> dict[str, torch.Tensor]:
280
+ flatten_tensor = torch.empty(
281
+ sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)),
282
+ dtype=assignment.dtypes[0],
283
+ device="cuda",
284
+ )
285
+ rank = assignment.loader_id * self.loader_every_n_ranks
286
+ dist.broadcast(flatten_tensor, src=rank, group=self.pg)
287
+ self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
288
+
289
+ ret: dict[str, torch.Tensor] = {}
290
+ loc = 0
291
+ for fqn, shape, dtype in zip(
292
+ assignment.fqns, assignment.shapes, assignment.dtypes
293
+ ):
294
+ n_ele = math.prod(shape)
295
+ ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape)
296
+ loc += n_ele
297
+ return ret
298
+
299
+ def _reshard(
300
+ self,
301
+ results: list[dict[str, torch.Tensor]],
302
+ state_dict: dict[str, torch.Tensor],
303
+ ) -> None:
304
+ def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]):
305
+ titan_fqns = self.stored_fqn_to_titan_fqn[fqn]
306
+ assert len(titan_fqns) == len(full_tensors)
307
+ for titan_fqn, full_tensor in zip(titan_fqns, full_tensors):
308
+ dtensor = state_dict[titan_fqn]
309
+ logger.info(f"{titan_fqn} {full_tensor.sum()}")
310
+ assert isinstance(dtensor, DTensor)
311
+ shape, offset = compute_local_shape_and_global_offset(
312
+ full_tensor.shape, dtensor.device_mesh, dtensor.placements
313
+ )
314
+ slices = [
315
+ slice(cur_offset, cur_offset + cur_shape)
316
+ for cur_shape, cur_offset in zip(shape, offset)
317
+ ]
318
+ logger.info(
319
+ f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} "
320
+ f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} "
321
+ f"{dtensor.placements=} {dtensor.device_mesh=} "
322
+ )
323
+ dtensor.to_local().copy_(full_tensor[slices])
324
+
325
+ def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
326
+ if "wqkv" in fqn:
327
+ if "layer_norm" in fqn:
328
+ return (shards[0],)
329
+ return split_fused_qkv(shards)
330
+
331
+ shard_dim = get_shard_dim(fqn)
332
+ if shard_dim is None:
333
+ return (shards[0],)
334
+ return (torch.cat(shards, dim=shard_dim),)
335
+
336
+ fqns = list(results[0].keys())
337
+ for result in results:
338
+ assert list(result.keys()) == fqns
339
+
340
+ for fqn in fqns:
341
+ full_tensors = _concat_shards(fqn, [result[fqn] for result in results])
342
+ _inplace_copy(fqn, full_tensors)
343
+
344
+
345
+ def _create_verified_state_dict(
346
+ pg: dist.ProcessGroup, mesh: DeviceMesh
347
+ ) -> dict[str, torch.Tensor]:
348
+ placements = [Shard(0)]
349
+ state_dict = {
350
+ "tok_embeddings.weight": torch.rand(
351
+ 25256 * 8, 5120, device="cuda", dtype=torch.bfloat16
352
+ ),
353
+ "layers.47.attention.wqkv.layer_norm_weight": torch.rand(
354
+ 5120, device="cuda", dtype=torch.bfloat16
355
+ ),
356
+ "layers.47.attention.wq.weight": torch.rand(
357
+ 640 * 8, 5120, device="cuda", dtype=torch.bfloat16
358
+ ),
359
+ "layers.47.attention.wk.weight": torch.rand(
360
+ 128 * 8, 5120, device="cuda", dtype=torch.bfloat16
361
+ ),
362
+ "layers.47.attention.wv.weight": torch.rand(
363
+ 128 * 8, 5120, device="cuda", dtype=torch.bfloat16
364
+ ),
365
+ "layers.47.attention.wo.weight": torch.rand(
366
+ 5120, 640 * 8, device="cuda", dtype=torch.bfloat16
367
+ ),
368
+ # "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16),
369
+ # "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
370
+ # "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
371
+ "layers.47.feed_forward.w_in_shared_FD.weight": torch.rand(
372
+ 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
373
+ ),
374
+ "layers.47.feed_forward.w_out_shared_DF.weight": torch.rand(
375
+ 5120, 1024 * 8, device="cuda", dtype=torch.bfloat16
376
+ ),
377
+ "layers.47.feed_forward.w_swiglu_FD.weight": torch.rand(
378
+ 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
379
+ ),
380
+ "layers.47.feed_forward.norm.weight": torch.rand(
381
+ 5120, device="cuda", dtype=torch.bfloat16
382
+ ),
383
+ "layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand(
384
+ 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
385
+ ),
386
+ "layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand(
387
+ 131072 * 8, 5120, device="cuda", dtype=torch.bfloat16
388
+ ),
389
+ "layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand(
390
+ 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
391
+ ),
392
+ }
393
+ return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()}
394
+
395
+
396
+ def _verify_state_dict(
397
+ state_dict: dict[str, torch.Tensor], path: str, rank: int
398
+ ) -> None:
399
+ stored_state_dicts = [
400
+ torch.load(
401
+ os.path.join(path, f"consolidated.0{i}.pth"),
402
+ map_location="cpu",
403
+ weights_only=False,
404
+ mmap=True,
405
+ )
406
+ for i in range(8)
407
+ ]
408
+
409
+ def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None:
410
+ logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ")
411
+ shards = [stored_state_dicts[i][fqn] for i in range(8)]
412
+ full_tensor = dtensor.full_tensor()
413
+ logger.info(f"Gather {fqn} {full_tensor.shape} completely.")
414
+
415
+ if rank > 0:
416
+ return
417
+
418
+ if len(shards[0].shape) == 1:
419
+ assert full_tensor.shape == shards[0].shape, fqn
420
+ assert torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn
421
+ return
422
+ elif shards[0].shape[0] == full_tensor.shape[0]:
423
+ concat_shards = torch.cat(shards, dim=1)
424
+ logger.info(f"Load {fqn} completely.")
425
+ elif shards[0].shape[1] == full_tensor.shape[1]:
426
+ concat_shards = torch.cat(shards, dim=0)
427
+ logger.info(f"Load {fqn} completely.")
428
+
429
+ concat_shards = concat_shards.to(device="cuda")
430
+ logger.info(f"Move to GPU {fqn} completely.")
431
+
432
+ assert concat_shards.shape == full_tensor.shape, fqn
433
+ assert concat_shards.dtype == full_tensor.dtype, fqn
434
+ assert concat_shards.device == full_tensor.device, fqn
435
+ assert torch.allclose(concat_shards, full_tensor), fqn
436
+
437
+ for k, v in state_dict.items():
438
+ if "wq" in k and "wqkv" not in k:
439
+ pass
440
+ elif "wk" in k:
441
+ pass
442
+ elif "wv" in k:
443
+ pass
444
+ else:
445
+ assert v is not None, k
446
+ read_and_verify_tensor(k, v)
447
+
448
+
449
+ if __name__ == "__main__":
450
+ init_logger()
451
+ config = JobConfig()
452
+ config.parser.add_argument(
453
+ "--checkpoint.convert_path",
454
+ type=str,
455
+ default="",
456
+ help="""Specify the path of the target checkpoint to convert.""",
457
+ )
458
+ config.parser.add_argument(
459
+ "--checkpoint.convert_load_every_n_ranks",
460
+ type=int,
461
+ default=8,
462
+ help="""
463
+ Specify the interval at which ranks are assigned to load checkpoints.
464
+
465
+ For example, if this number is 4, then ranks 0, 4, 8, ... will load the
466
+ checkpoint. Each loader is responsible for loading one file. If there
467
+ are more loaders than files, only the first few loaders will be assigned
468
+ to load the checkpoint. The default value is 8.
469
+ """,
470
+ )
471
+ config.parser.add_argument(
472
+ "--checkpoint.fake_model",
473
+ action="store_true",
474
+ help="""If true, the model will be fake.""",
475
+ )
476
+ config.parse_args()
477
+ assert config.checkpoint.convert_path != ""
478
+
479
+ trainer: Optional[Trainer] = None
480
+
481
+ try:
482
+ trainer = Trainer(config)
483
+ if os.path.exists(trainer.checkpointer.folder):
484
+ raise RuntimeError(
485
+ "The checkpoint folder already exists. Abort to avoid overwriting "
486
+ f"the checkpoint. {trainer.checkpointer.folder=}"
487
+ )
488
+ if config.checkpoint.fake_model:
489
+ state_dict = _create_verified_state_dict(
490
+ trainer.world_mesh.get_group(), trainer.world_mesh
491
+ )
492
+ else:
493
+ state_dict = trainer.checkpointer.states[MODEL].state_dict()
494
+
495
+ size = 0
496
+ for v in state_dict.values():
497
+ size += v.numel() * v.element_size()
498
+ logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
499
+
500
+ # PP is not supported yet; we will need to iterate over the PP dimension and
501
+ # extract the corresponding state_dict and device_mesh.
502
+ if "freqs_cis" in state_dict:
503
+ state_dict.pop("freqs_cis")
504
+
505
+ state_dict = CheckpointConverter(
506
+ process_group=trainer.world_mesh.get_group(),
507
+ path=config.checkpoint.convert_path,
508
+ loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks,
509
+ ).convert(state_dict)
510
+
511
+ class DummyModel:
512
+ def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
513
+ self._state_dict = state_dict
514
+
515
+ def state_dict(self) -> dict[str, torch.Tensor]:
516
+ return self._state_dict
517
+
518
+ if config.checkpoint.fake_model:
519
+ begin = time.time()
520
+ _verify_state_dict(
521
+ state_dict,
522
+ config.checkpoint.convert_path,
523
+ trainer.world_mesh.get_rank(),
524
+ )
525
+ dist.barrier()
526
+ logger.info(f"Verifies state_dict {time.time() - begin}.")
527
+ else:
528
+ # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
529
+ state_dict["freqs_cis"] = None
530
+ trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
531
+ trainer.checkpointer.model_weights_only = True
532
+ trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
533
+ trainer.checkpointer.save(curr_step=0, force=True)
534
+ time.sleep(2)
535
+ finally:
536
+ pass
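Two pieces of this converter can be sanity-checked in isolation: the fused-QKV split and the shard-dim lookup. A sketch, assuming the script's functions are importable, using the hard-coded [640, 128, 128] per-shard split and illustrative FQNs:

```python
import torch

# Assumed to be importable from the converter script above.
from torchtitan.experiments.llama4.scripts.convert_meta_to_dcp_with_gpus import (
    get_shard_dim,
    split_fused_qkv,
)

# One fused QKV shard per checkpoint file; 640/128/128 rows of q/k/v each.
shards = [torch.randn(640 + 128 + 128, 5120) for _ in range(8)]
q, k, v = split_fused_qkv(shards)
assert q.shape == (640 * 8, 5120)
assert k.shape == (128 * 8, 5120) and v.shape == (128 * 8, 5120)

assert get_shard_dim("layers.0.attention.wq.weight") == 0   # column parallel
assert get_shard_dim("layers.0.attention.wo.weight") == 1   # row parallel
assert get_shard_dim("layers.0.ffn_norm.weight") is None    # replicated
```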
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml ADDED
@@ -0,0 +1,65 @@
1
+ # TODO: this toml config is still under development
2
+
3
+ [job]
4
+ dump_folder = "./outputs"
5
+ description = "Llama 4 Maverick 17Bx128E training"
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 100
11
+
12
+ [metrics]
13
+ log_freq = 10
14
+ enable_tensorboard = false
15
+ save_tb_folder = "tb"
16
+
17
+ [model]
18
+ name = "llama4"
19
+ flavor = "17bx128e"
20
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
21
+ tokenizer_path = "./assets/tokenizer/tokenizer.model"
22
+ # converters = "float8"
23
+
24
+ [optimizer]
25
+ name = "AdamW"
26
+ lr = 4e-3
27
+ eps = 1e-15
28
+
29
+ [lr_scheduler]
30
+ warmup_steps = 600
31
+ lr_min = 0.1
32
+
33
+ [training]
34
+ batch_size = 1
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8
45
+ enable_async_tensor_parallel = false
46
+ pipeline_parallel_degree = 4
47
+ # pipeline_parallel_schedule = "interleaved1f1b"
48
+ # pipeline_parallel_microbatches = 2
49
+ context_parallel_degree = 1
50
+
51
+ [checkpoint]
52
+ enable_checkpoint = false
53
+ folder = "checkpoint"
54
+ interval = 500
55
+ model_weights_only = false
56
+ export_dtype = "float32"
57
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
58
+
59
+ [activation_checkpoint]
60
+ mode = 'full' # ['none', 'selective', 'full']
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+ filter_fqns = "output,router.gate"
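The values in this config can be sanity-checked with the standard library TOML parser (Python 3.11+); the path below assumes the repository root as the working directory.

```python
import tomllib  # standard library in Python 3.11+

# Path assumed relative to the repository root.
with open("torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml", "rb") as f:
    cfg = tomllib.load(f)

assert cfg["parallelism"]["tensor_parallel_degree"] == 8
assert cfg["parallelism"]["pipeline_parallel_degree"] == 4
assert cfg["training"]["seq_len"] == 8192
assert cfg["activation_checkpoint"]["mode"] == "full"
```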
torchtitan/experiments/multimodal/mm_dataset.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from datasets import Dataset, load_dataset
13
+ from datasets.distributed import split_dataset_by_node
14
+
15
+ from mm_collator import MultiModalCollator
16
+ from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer
17
+ from torch.distributed.checkpoint.stateful import Stateful
18
+ from torch.utils.data import IterableDataset
19
+ from transform import CLIPTransform
20
+ from utils import load_image
21
+
22
+ from torchtitan.components.dataloader import ParallelAwareDataloader
23
+ from torchtitan.config_manager import JobConfig
24
+ from torchtitan.tools.logging import logger
25
+
26
+
27
+ def _load_obelics_dataset(dataset_path: str):
28
+ """Load the OBELICS dataset with default configuration."""
29
+ return load_dataset(dataset_path, split="train", streaming=True)
30
+
31
+
32
+ def _process_obelics_sample(
33
+ sample: dict[str, Any], image_token: str = "<|image|>"
34
+ ) -> Dict[str, List[Union[str, "PIL.Image.Image"]]]:
35
+ """
36
+ This function formats samples from the OBELICS dataset
37
+ Returns:
38
+ Dict[str, Any]: The transformed sample with the following fields:
39
+ - images: List[PIL.Image.Image] with the loaded images
40
+ - text: str with the text of the sample ready to be tokenized including the image tokens
41
+ Example:
42
+ >>> formatted_sample = format_obelics(sample, image_token="<|image|>")
43
+ >>> print(formatted_sample["text"])
44
+ ... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :("
45
+ """
46
+ sample_images = [image for image in sample["images"] if image is not None]
47
+ sample_text = [
48
+ text if text is not None else image_token for text in sample["texts"]
49
+ ]
50
+ return {
51
+ "images": [load_image(image) for image in sample_images],
52
+ "text": "".join(map(str, sample_text)),
53
+ }
54
+
55
+
56
+ @dataclass
57
+ class DatasetConfig:
58
+ path: str
59
+ loader: Callable
60
+ sample_processor: Callable
61
+
62
+
63
+ # Add your dataset here - more information at docs/datasets.md
64
+ MM_DATASETS = {
65
+ "obelics": DatasetConfig(
66
+ path="HuggingFaceM4/OBELICS",
67
+ loader=_load_obelics_dataset,
68
+ sample_processor=_process_obelics_sample,
69
+ ),
70
+ }
71
+
72
+
73
+ def _validate_mm_dataset(
74
+ dataset_name: str, dataset_path: str = None
75
+ ) -> tuple[str, Callable, Callable]:
76
+ """Validate dataset name and path."""
77
+ if dataset_name not in MM_DATASETS:
78
+ raise ValueError(
79
+ f"Dataset {dataset_name} is not supported. "
80
+ f"Supported datasets are: {list(MM_DATASETS.keys())}"
81
+ )
82
+
83
+ config = MM_DATASETS[dataset_name]
84
+ path = dataset_path or config.path
85
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
86
+ return path, config.loader, config.sample_processor
87
+
88
+
89
+ class MultiModalDataset(IterableDataset, Stateful):
90
+ """PyTorch MultiModal Dataset.
91
+
92
+ Args:
93
+ dataset_name (str): name of the dataset to load
94
+ tokenizer (Tokenizer):
95
+ Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
96
+ world_size (int): number of data parallel processes participating in training
97
+ rank (int): rank of the current data parallel process
98
+ infinite (bool): whether to loop infinitely over the dataset
99
+
100
+ We currently ONLY support the OBELICS dataset
101
+
102
+ Example use:
103
+ >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
104
+ >>> for batch in Dataloader(ds, batch_size=8):
105
+ print(f"Batch size: {len(batch)}")
106
+ Batch size: 8
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ dataset_name: str,
112
+ dataset_path: Optional[str],
113
+ tokenizer: Tokenizer,
114
+ image_token: str = "<|image|>",
115
+ tile_size: int = 448,
116
+ max_num_tiles: int = 4,
117
+ seq_len: int = 2048,
118
+ dp_rank: int = 0,
119
+ dp_world_size: int = 1,
120
+ infinite: bool = False,
121
+ ) -> None:
122
+ # Force lowercase for consistent comparison
123
+ dataset_name = dataset_name.lower()
124
+
125
+ path, dataset_loader, sample_processor = _validate_mm_dataset(
126
+ dataset_name, dataset_path
127
+ )
128
+ ds = dataset_loader(path)
129
+
130
+ # TODO: support shuffling
131
+ self.dataset_name = dataset_name
132
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
133
+ self._tokenizer = tokenizer
134
+ self.seq_len = seq_len
135
+ self.infinite = infinite
136
+ self._sample_processor = sample_processor
137
+ self.image_token = (
138
+ image_token # TODO(tj.solergibert) Add `image_token` to JobConfig
139
+ )
140
+ # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
141
+ self.transform_image = CLIPTransform(
142
+ image_mean=(
143
+ 0.48145466,
144
+ 0.4578275,
145
+ 0.40821073,
146
+ ), # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?,
147
+ image_std=(0.26862954, 0.26130258, 0.27577711),
148
+ tile_size=tile_size,
149
+ possible_resolutions=None,
150
+ max_num_tiles=max_num_tiles,
151
+ resample="bilinear",
152
+ resize_to_max_canvas=False,
153
+ )
154
+
155
+ # variables for checkpointing
156
+ self._sample_idx = 0
157
+
158
+ def __iter__(self):
159
+
160
+ while True:
161
+ for sample in self._get_data_iter():
162
+ try:
163
+ sample = self._sample_processor(
164
+ sample, image_token=self.image_token
165
+ )
166
+ except Exception:
167
+ continue
168
+ self._sample_idx += 1
169
+
170
+ # CLIP Transform
171
+ encoder_input = {"images": [], "aspect_ratio": []}
172
+ for image in sample["images"]:
173
+ out = self.transform_image(image)
174
+ encoder_input["images"].append(out["image"])
175
+ encoder_input["aspect_ratio"].append(out["aspect_ratio"])
176
+ sample["encoder_input"] = encoder_input
177
+
178
+ # Tokenize
179
+ tokens = self._tokenizer.encode(
180
+ sample["text"],
181
+ bos=True,
182
+ eos=True,
183
+ allowed_special=set(["<|image|>"]),
184
+ )
185
+ sample["input_ids"] = torch.LongTensor(tokens[:-1])
186
+ sample["labels"] = torch.LongTensor(tokens[1:])
187
+ # Mask BOS, EOS & image tokens from the loss
188
+ sample["labels"] = torch.where(
189
+ torch.isin(
190
+ sample["labels"],
191
+ torch.LongTensor(
192
+ [
193
+ self._tokenizer.bos_id,
194
+ self._tokenizer.eos_id,
195
+ self._tokenizer.image_id,
196
+ ]
197
+ ),
198
+ ),
199
+ IGNORE_INDEX,
200
+ sample["labels"],
201
+ )
202
+ # Truncate
203
+ sample["input_ids"], sample["labels"] = (
204
+ sample["input_ids"][: self.seq_len],
205
+ sample["labels"][: self.seq_len],
206
+ )
207
+ yield sample
208
+
209
+ if not self.infinite:
210
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
211
+ break
212
+ else:
213
+ # Reset offset for the next iteration
214
+ self._sample_idx = 0
215
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
216
+
217
+ def _get_data_iter(self):
218
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
219
+ return iter([])
220
+
221
+ it = iter(self._data)
222
+ for _ in range(self._sample_idx):
223
+ next(it)
224
+ return it
225
+
226
+ def load_state_dict(self, state_dict):
227
+ self._sample_idx = state_dict["sample_idx"]
228
+
229
+ def state_dict(self):
230
+ return {"sample_idx": self._sample_idx}
231
+
232
+
233
+ def build_mm_dataloader(
234
+ dp_world_size: int,
235
+ dp_rank: int,
236
+ tokenizer: Tokenizer,
237
+ job_config: JobConfig,
238
+ infinite: bool = True,
239
+ ) -> ParallelAwareDataloader:
240
+ """Build a data loader for HuggingFace datasets."""
241
+ dataset_name = job_config.training.dataset
242
+ dataset_path = job_config.training.dataset_path
243
+ batch_size = job_config.training.batch_size
244
+ seq_len = job_config.training.seq_len
245
+ pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
246
+ padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig
247
+
248
+ hf_ds = MultiModalDataset(
249
+ dataset_name=dataset_name,
250
+ dataset_path=dataset_path,
251
+ tokenizer=tokenizer,
252
+ seq_len=seq_len,
253
+ dp_rank=dp_rank,
254
+ dp_world_size=dp_world_size,
255
+ infinite=infinite,
256
+ )
257
+
258
+ collate_fn = MultiModalCollator(
259
+ padding_idx=padding_idx, pad_max_tiles=pad_max_tiles
260
+ )
261
+
262
+ return ParallelAwareDataloader(
263
+ dataset=hf_ds,
264
+ dp_rank=dp_rank,
265
+ dp_world_size=dp_world_size,
266
+ batch_size=batch_size,
267
+ collate_fn=collate_fn,
268
+ )
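The label construction inside `__iter__` (shift the token stream by one, then mask BOS, EOS, and image positions with IGNORE_INDEX) can be illustrated without the real tokenizer; the token IDs below are made up.

```python
import torch

IGNORE_INDEX = -100
bos_id, eos_id, image_id = 0, 1, 9  # invented IDs for illustration only
tokens = [bos_id, image_id, image_id, 5, 6, 7, eos_id]

input_ids = torch.LongTensor(tokens[:-1])
labels = torch.LongTensor(tokens[1:])
special = torch.LongTensor([bos_id, eos_id, image_id])
labels = torch.where(torch.isin(labels, special), IGNORE_INDEX, labels)

print(input_ids.tolist())  # [0, 9, 9, 5, 6, 7]
print(labels.tolist())     # [-100, -100, 5, 6, 7, -100]
```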
torchtitan/experiments/multimodal/requirements.txt ADDED
@@ -0,0 +1 @@
1
+ torchvision
torchtitan/experiments/multimodal/tests/test_utils.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from typing import Optional, Union
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ def fixed_init_tensor(
16
+ shape: torch.Size,
17
+ min_val: Union[float, int] = 0.0,
18
+ max_val: Union[float, int] = 1.0,
19
+ nonlinear: bool = False,
20
+ dtype: torch.dtype = torch.float,
21
+ ):
22
+ """
23
+ Utility for generating deterministic tensors of a given shape. In general, initializers
24
+ like torch.ones or torch.eye can result in trivial outputs. This utility
25
+ generates a range tensor [min_val, max_val) of a specified dtype, applies
26
+ a sine function if nonlinear=True, then reshapes to the appropriate shape.
27
+ """
28
+ n_elements = math.prod(shape)
29
+ step_size = (max_val - min_val) / n_elements
30
+ x = torch.arange(min_val, max_val, step_size, dtype=dtype)
31
+ x = x.reshape(shape)
32
+ if nonlinear:
33
+ return torch.sin(x)
34
+ return x
35
+
36
+
37
+ @torch.no_grad
38
+ def fixed_init_model(
39
+ model: nn.Module,
40
+ min_val: Union[float, int] = 0.0,
41
+ max_val: Union[float, int] = 1.0,
42
+ nonlinear: bool = False,
43
+ dtype: Optional[torch.dtype] = None,
44
+ ):
45
+ """
46
+ This utility initializes all parameters of a model deterministically using the
47
+ function fixed_init_tensor above. See that docstring for details of each parameter.
48
+ """
49
+ for _, param in model.named_parameters():
50
+ param.copy_(
51
+ fixed_init_tensor(
52
+ param.shape,
53
+ min_val=min_val,
54
+ max_val=max_val,
55
+ nonlinear=nonlinear,
56
+ dtype=param.dtype if dtype is None else dtype,
57
+ )
58
+ )
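For example, `fixed_init_tensor` with shape (2, 3) and max_val=6.0 is just a reshaped arange, and `fixed_init_model` applies the same pattern to every parameter (sketch; the import path is assumed from the file location above):

```python
import torch

# Assumed import path based on the file location above.
from torchtitan.experiments.multimodal.tests.test_utils import fixed_init_model, fixed_init_tensor

t = fixed_init_tensor(torch.Size([2, 3]), min_val=0.0, max_val=6.0)
print(t)  # tensor([[0., 1., 2.], [3., 4., 5.]])

layer = torch.nn.Linear(4, 2, bias=False)
fixed_init_model(layer, min_val=-1.0, max_val=1.0, nonlinear=True)
print(layer.weight)  # deterministic sine-shaped values, identical across runs
```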
torchtitan/experiments/multimodal/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+ import os
11
+ from pathlib import Path
12
+ from typing import (
13
+ AbstractSet,
14
+ Any,
15
+ cast,
16
+ Collection,
17
+ Dict,
18
+ Iterator,
19
+ List,
20
+ Literal,
21
+ Mapping,
22
+ Optional,
23
+ Sequence,
24
+ Union,
25
+ )
26
+
27
+ import tiktoken
28
+ import torch
29
+ from tiktoken.load import load_tiktoken_bpe
30
+
31
+ from torchtitan.components.tokenizer import Tokenizer
32
+ from torchtitan.config_manager import JobConfig
33
+ from torchtitan.tools.logging import logger
34
+
35
+ IMAGE_TOKEN_ID = 128256
36
+ IGNORE_INDEX = -100
37
+
38
+
39
+ class TikTokenizer(Tokenizer):
40
+ """
41
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
42
+
43
+ Args:
44
+ model_path (str): The path to the Tiktoken model file.
45
+ """
46
+
47
+ special_tokens: Dict[str, int]
48
+
49
+ num_reserved_special_tokens = 256
50
+
51
+ pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950
52
+
53
+ def __init__(self, model_path: str):
54
+ super().__init__(model_path)
55
+ assert os.path.isfile(model_path), model_path
56
+
57
+ mergeable_ranks = load_tiktoken_bpe(model_path)
58
+ num_base_tokens = len(mergeable_ranks)
59
+ special_tokens = [
60
+ "<|begin_of_text|>",
61
+ "<|end_of_text|>",
62
+ "<|reserved_special_token_0|>",
63
+ "<|reserved_special_token_1|>",
64
+ "<|reserved_special_token_2|>",
65
+ "<|reserved_special_token_3|>",
66
+ "<|start_header_id|>",
67
+ "<|end_header_id|>",
68
+ "<|reserved_special_token_4|>",
69
+ "<|eot_id|>", # end of turn
70
+ ] + [
71
+ f"<|reserved_special_token_{i}|>"
72
+ for i in range(5, self.num_reserved_special_tokens - 5)
73
+ ]
74
+ self.special_tokens = {
75
+ token: num_base_tokens + i for i, token in enumerate(special_tokens)
76
+ }
77
+ self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
78
+ self.model = tiktoken.Encoding(
79
+ name=Path(model_path).name,
80
+ pat_str=self.pat_str,
81
+ mergeable_ranks=mergeable_ranks,
82
+ special_tokens=self.special_tokens,
83
+ )
84
+
85
+ self._n_words: int = self.model.n_vocab
86
+ # BOS / EOS token IDs
87
+ self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
88
+ self.eos_id: int = self.special_tokens["<|end_of_text|>"]
89
+ self.pad_id: int = -1
90
+ self.image_id = IMAGE_TOKEN_ID
91
+ self.stop_tokens = {
92
+ self.special_tokens["<|end_of_text|>"],
93
+ self.special_tokens["<|eot_id|>"],
94
+ }
95
+ logger.info(
96
+ f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
97
+ )
98
+
99
+ def encode(
100
+ self,
101
+ s: str,
102
+ *,
103
+ bos: bool,
104
+ eos: bool,
105
+ allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
106
+ disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
107
+ ) -> List[int]:
108
+ """
109
+ Encodes a string into a list of token IDs.
110
+
111
+ Args:
112
+ s (str): The input string to be encoded.
113
+ bos (bool): Whether to prepend the beginning-of-sequence token.
114
+ eos (bool): Whether to append the end-of-sequence token.
115
+ allowed_special ("all"|set[str]): special tokens allowed to appear in the string
116
+ disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string
117
+
118
+ Returns:
119
+ list[int]: A list of token IDs.
120
+
121
+ By default, setting disallowed_special=() encodes a string by ignoring
122
+ special tokens. Specifically:
123
+ - Setting `disallowed_special` to () will cause all text corresponding
124
+ to special tokens to be encoded as natural text (instead of raising
125
+ an error).
126
+ - Setting `allowed_special` to "all" will cause all text corresponding
127
+ to special tokens to be encoded as special tokens.
128
+ """
129
+ assert type(s) is str
130
+ allowed_special = allowed_special or set()
131
+ disallowed_special = disallowed_special or ()
132
+
133
+ # The tiktoken tokenizer can handle <=400k chars without
134
+ # pyo3_runtime.PanicException.
135
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
136
+
137
+ # https://github.com/openai/tiktoken/issues/195
138
+ # Here we iterate over subsequences and split if we exceed the limit
139
+ # of max consecutive non-whitespace or whitespace characters.
140
+ MAX_NO_WHITESPACES_CHARS = 25_000
141
+
142
+ substrs = (
143
+ substr
144
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
145
+ for substr in self._split_whitespaces_or_nonwhitespaces(
146
+ s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
147
+ )
148
+ )
149
+ t: List[int] = []
150
+ for substr in substrs:
151
+ t.extend(
152
+ self.model.encode(
153
+ substr,
154
+ allowed_special=allowed_special,
155
+ disallowed_special=disallowed_special,
156
+ )
157
+ )
158
+ if bos:
159
+ t.insert(0, self.bos_id)
160
+ if eos:
161
+ t.append(self.eos_id)
162
+ return t
163
+
164
+ def decode(self, t: Sequence[int]) -> str:
165
+ """
166
+ Decodes a list of token IDs into a string.
167
+
168
+ Args:
169
+ t (List[int]): The list of token IDs to be decoded.
170
+
171
+ Returns:
172
+ str: The decoded string.
173
+ """
174
+ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
175
+ return self.model.decode(cast(List[int], t))
176
+
177
+ @staticmethod
178
+ def _split_whitespaces_or_nonwhitespaces(
179
+ s: str, max_consecutive_slice_len: int
180
+ ) -> Iterator[str]:
181
+ """
182
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
183
+ consecutive whitespaces or consecutive non-whitespaces.
184
+ """
185
+ current_slice_len = 0
186
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
187
+ slice_start = 0
188
+
189
+ for i in range(len(s)):
190
+ is_now_space = s[i].isspace()
191
+
192
+ if current_slice_is_space ^ is_now_space:
193
+ current_slice_len = 1
194
+ current_slice_is_space = is_now_space
195
+ else:
196
+ current_slice_len += 1
197
+ if current_slice_len > max_consecutive_slice_len:
198
+ yield s[slice_start:i]
199
+ slice_start = i
200
+ current_slice_len = 1
201
+ yield s[slice_start:]
202
+
203
+ def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
204
+ """
205
+ Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
206
+ """
207
+ # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
208
+ # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
209
+ # & everything else expects `tokens`
210
+ text = sample["text"]
211
+ tokens = self.encode(
212
+ text, bos=True, eos=True, allowed_special=set(["<|image|>"])
213
+ )
214
+ input_ids = torch.LongTensor(tokens[:-1])
215
+ labels = torch.LongTensor(tokens[1:])
216
+ labels = torch.where(
217
+ torch.isin(
218
+ labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
219
+ ),
220
+ IGNORE_INDEX,
221
+ labels,
222
+ )
223
+
224
+ assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete
225
+
226
+ sample.update({"tokens": input_ids, "labels": labels})
227
+
228
+ return sample
229
+
230
+
231
+ def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
232
+ return TikTokenizer(job_config.model.tokenizer_path)
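To make the label-masking step in `encode_multimodal` concrete, here is a hedged, self-contained sketch using toy token ids instead of a real tokenizer model file; the ids are hypothetical and chosen only for illustration.

```python
# Sketch of the shift-and-mask logic: special tokens never contribute to the loss.
import torch

IGNORE_INDEX = -100
bos_id, eos_id, image_id = 0, 1, 9  # hypothetical ids for illustration

tokens = [bos_id, 5, 9, 7, 8, eos_id]
input_ids = torch.LongTensor(tokens[:-1])  # inputs:  [0, 5, 9, 7, 8]
labels = torch.LongTensor(tokens[1:])      # targets: [5, 9, 7, 8, 1]

labels = torch.where(
    torch.isin(labels, torch.LongTensor([bos_id, eos_id, image_id])),
    IGNORE_INDEX,
    labels,
)
print(labels.tolist())  # [5, -100, 7, 8, -100]
```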
torchtitan/experiments/multimodal/utils.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ from collections import defaultdict
10
+
11
+ from pathlib import Path
12
+ from typing import List, Optional, Set, Tuple, Union
13
+ from urllib import request
14
+
15
+ import torch
16
+ import torchvision
17
+ from torchvision.transforms.v2 import functional as F
18
+
19
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py
20
+ def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
21
+ """
22
+ Divides a tensor into equally sized tiles. The tensor's height and width must be divisible by tile_size.
23
+
24
+ Args:
25
+ image (torch.Tensor): Input image to crop into tiles.
26
+ tile_size (int): Size of each tile.
27
+
28
+ Returns:
29
+ torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size]
30
+
31
+ Examples:
32
+ >>> image = torch.rand(3, 200, 300)
33
+ >>> tiles = tile_crop(image, tile_size=50)
34
+ >>> tiles.shape # 4x6 = 24 tiles
35
+ torch.Size([24, 3, 50, 50])
36
+
37
+ >>> image = torch.rand(3, 400, 600)
38
+ >>> tiles = tile_crop(image, tile_size=200)
39
+ >>> tiles.shape # 2x3 = 6 tiles
40
+ torch.Size([6, 3, 200, 200])
41
+ """
42
+
43
+ channel_size, height, width = image.shape
44
+
45
+ # assert sizes are divisible
46
+ assert (
47
+ height % tile_size == 0 and width % tile_size == 0
48
+ ), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
49
+
50
+ # Reshape to split height and width into tile_size blocks
51
+ tiles_height = height // tile_size
52
+ tiles_width = width // tile_size
53
+
54
+ reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
55
+
56
+ # Transpose to bring tiles together
57
+ # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
58
+ transposed = reshaped.permute(1, 3, 0, 2, 4)
59
+
60
+ # Flatten the tiles
61
+ tiles = transposed.contiguous().view(
62
+ tiles_height * tiles_width, channel_size, tile_size, tile_size
63
+ )
64
+
65
+ return tiles
66
+
67
+
68
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
69
+ def resize_with_pad(
70
+ image: torch.Tensor,
71
+ target_size: Tuple[int, int],
72
+ resample: torchvision.transforms.InterpolationMode,
73
+ max_size: Optional[int] = None,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Resizes and pads an image to target_size without causing distortion.
77
+ The user can set max_size to limit upscaling when target_size exceeds image_size.
78
+
79
+ Args:
80
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
81
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
82
+ resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
83
+ Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
84
+ InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
85
+ max_size (Optional[int]): The maximum size to upscale the image to.
86
+ If None, will upscale up to target_size.
87
+
88
+ Returns:
89
+ torch.Tensor: The resized and padded image tensor in the format [..., H, W].
90
+
91
+ Examples:
92
+
93
+ Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
94
+ and then padded from (448, 1194) to (448, 1344).
95
+
96
+ >>> max_size = None
97
+ >>> image = torch.rand([3, 300, 800])
98
+ >>> target_size = (448, 1344)
99
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
100
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
101
+
102
+ Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).
103
+
104
+ >>> max_size = 600
105
+ >>> image = torch.rand([3, 300, 800])
106
+ >>> target_size = (448, 1344)
107
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
108
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
109
+
110
+ Example 3: The image will be downscaled from (500, 1000) to (224, 448),
111
+ and padded from (224, 448) to (448, 448).
112
+
113
+ >>> max_size = 600
114
+ >>> image = torch.rand([3, 500, 1000])
115
+ >>> target_size = (448, 448)
116
+ >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
117
+ >>> output = resize_with_pad(image, target_size, resample, max_size)
118
+
119
+ """
120
+
121
+ image_height, image_width = image.shape[-2:]
122
+ image_size = (image_height, image_width)
123
+
124
+ # If target_size requires upscaling, we might want to limit the upscaling to max_size
125
+ if max_size is not None:
126
+ new_target_height = min(max(image_height, max_size), target_size[0])
127
+ new_target_width = min(max(image_width, max_size), target_size[1])
128
+ target_size_resize = (new_target_height, new_target_width)
129
+ else:
130
+ target_size_resize = target_size
131
+
132
+ # resize to target_size while preserving aspect ratio
133
+ new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
134
+ image_size=image_size,
135
+ target_size=target_size_resize,
136
+ )
137
+
138
+ image = F.resize(
139
+ inpt=image,
140
+ size=list(new_size_preserving_aspect_ratio),
141
+ interpolation=resample,
142
+ antialias=True,
143
+ )
144
+
145
+ image = _pad_image_top_left(image=image, target_size=target_size)
146
+
147
+ return image
148
+
149
+
150
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
151
+ def _pad_image_top_left(
152
+ image: torch.Tensor,
153
+ target_size: Tuple[int, int],
154
+ ) -> torch.Tensor:
155
+ """
156
+ Places the image at the top left of the canvas and pads the right and bottom with zeros
157
+ to fit the target resolution. If target_size < image_size, it will crop the image.
158
+
159
+ Args:
160
+ image (torch.Tensor): The input image tensor in the format [..., H, W].
161
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
162
+
163
+ Returns:
164
+ torch.Tensor: The padded image tensor in the format [..., H, W].
165
+ """
166
+
167
+ image_size = image.shape[-2:]
168
+
169
+ height, width = image_size
170
+ target_height, target_width = target_size
171
+
172
+ pad_x = target_width - width
173
+ pad_y = target_height - height
174
+
175
+ padding = [0, 0, pad_x, pad_y]
176
+ return F.pad(inpt=image, padding=padding)
177
+
178
+
179
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
180
+ def _get_max_res_without_distortion(
181
+ image_size: Tuple[int, int],
182
+ target_size: Tuple[int, int],
183
+ ) -> Tuple[int, int]:
184
+ """
185
+ Determines the maximum resolution to which an image can be resized to without distorting its
186
+ aspect ratio, based on the target resolution.
187
+
188
+ For example, if image_size = (200,400) and target_size = (600,800),
189
+ scale_h = 600/200 = 3
190
+ scale_w = 800/400 = 2
191
+ So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2
192
+
193
+ Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w
194
+
195
+ Args:
196
+ image_size (Tuple[int, int]): The original resolution of the image.
197
+ target_size (Tuple[int, int]): The desired resolution to fit the image into.
198
+ Returns:
199
+ Tuple[int, int]: The optimal dimensions to which the image should be resized.
200
+ Examples:
201
+ >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200))
202
+ (133, 200)
203
+ >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300))
204
+ (450, 337)
205
+ """
206
+
207
+ original_height, original_width = image_size
208
+ target_height, target_width = target_size
209
+
210
+ scale_w = target_width / original_width
211
+ scale_h = target_height / original_height
212
+
213
+ if scale_w < scale_h:
214
+ new_width = target_width
215
+ new_height = min(math.floor(original_height * scale_w), target_height)
216
+ else:
217
+ new_height = target_height
218
+ new_width = min(math.floor(original_width * scale_h), target_width)
219
+
220
+ return new_height, new_width
221
+
222
+
223
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
224
+ def _get_factors(n: int) -> Set[int]:
225
+ """
226
+ Calculate all factors of a given number, i.e. a divisor that leaves no remainder.
227
+
228
+ Args:
229
+ n (int): The number to find factors for.
230
+
231
+ Returns:
232
+ set: A set containing all factors of the number.
233
+
234
+ Examples:
235
+ >>> _get_factors(n=12)
236
+ {1, 2, 3, 4, 6, 12}
237
+ """
238
+ factors_set = set()
239
+
240
+ for i in range(1, int(n**0.5) + 1):
241
+ if n % i == 0:
242
+ factors_set.add(i)
243
+ factors_set.add(n // i)
244
+ return factors_set
245
+
246
+
247
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
248
+ def get_canvas_best_fit(
249
+ image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool
250
+ ) -> Tuple[int, int]:
251
+ """
252
+ Determines the best canvas possible from a list of possible resolutions to
253
+ resize an image to, without distortion.
254
+
255
+ For each possible resolution, calculates the scaling factors for
256
+ width and height, and selects the smallest one, which is the limiting side.
257
+ E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x,
258
+ then the maximum upscaling without distortion is min(2, 1.5) = 1.5.
259
+
260
+ If there are multiple canvases that satisfy the conditions,
261
+ we pick the one with the lowest area to minimize padding.
262
+
263
+ Args:
264
+ image (torch.Tensor): The image we want to fit into a canvas.
265
+ possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
266
+ row represents a possible canvas.
267
+ resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling.
268
+ If False, pick the canvas that minimizes downscaling, including no downscaling at all.
269
+
270
+ Returns:
271
+ Tuple[int, int]: The best resolution to fit the image into.
272
+
273
+ Examples:
274
+ >>> image = torch.rand(3, 200, 300)
275
+ >>> possible_resolutions = torch.tensor([
276
+ ... [224, 672],
277
+ ... [672, 224],
278
+ ... [224, 448],
279
+ ... [448, 224],
280
+ ... [224, 224]
281
+ ... ])
282
+ >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False)
283
+ (224, 448)
284
+
285
+ In the example above, we calculate the scaling factors for each possible resolution
286
+
287
+ >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
288
+ >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
289
+ >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
290
+
291
+ Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest
292
+
293
+ >>> upscaling_options = torch.tensor([1.1200, 1.1200])
294
+ >>> selected_scale = torch.tensor(1.1200)
295
+
296
+ There are two possible options, so we pick the one with the smallest area
297
+
298
+ >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively
299
+ >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area
300
+ """
301
+
302
+ original_height, original_width = image.shape[-2:]
303
+
304
+ # possible resolutions heights/widths
305
+ target_heights, target_widths = (
306
+ possible_resolutions[:, 0],
307
+ possible_resolutions[:, 1],
308
+ )
309
+
310
+ # scaling factors to resize the image without distortion
311
+ scale_w = target_widths / original_width
312
+ scale_h = target_heights / original_height
313
+
314
+ # get limiting side scaling -> no distortion
315
+ scales = torch.where(scale_w > scale_h, scale_h, scale_w)
316
+
317
+ # filter only scales that allow upscaling
318
+ upscaling_options = scales[scales >= 1]
319
+ if len(upscaling_options) > 0:
320
+ if resize_to_max_canvas:
321
+ selected_scale = torch.max(upscaling_options)
322
+ else:
323
+ selected_scale = torch.min(upscaling_options)
324
+ else:
325
+ # no upscaling possible,
326
+ # get the minimum downscaling (max scale for scales<1)
327
+ downscaling_options = scales[scales < 1]
328
+ selected_scale = torch.max(downscaling_options)
329
+
330
+ # get all resolutions that support this scaling factor,
331
+ # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
332
+ chosen_canvas = possible_resolutions[scales == selected_scale]
333
+
334
+ # if there are multiple resolutions,
335
+ # get the one with minimum area to reduce padding
336
+ if len(chosen_canvas) > 1:
337
+ areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
338
+ optimal_idx = torch.argmin(areas)
339
+ optimal_canvas = chosen_canvas[optimal_idx]
340
+ else:
341
+ optimal_canvas = chosen_canvas[0]
342
+
343
+ return tuple(optimal_canvas.tolist())
344
+
345
+
346
+ # NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py
347
+ def find_supported_resolutions(
348
+ max_num_tiles: int, tile_size: int
349
+ ) -> List[Tuple[int, int]]:
350
+ """
351
+ Computes all combinations of resolutions, multiples of tile_size,
352
+ that contain up to max_num_tiles. Useful for when dividing an image into tiles.
353
+
354
+ For example, if we want at most 2 tiles per image, then we can support the
355
+ following resolutions: (1x1, 1x2, 2x1) * tile_size
356
+
357
+ Args:
358
+ max_num_tiles (int): Maximum number of tiles.
359
+ tile_size (int): Size of the side of the tile.
360
+
361
+ Returns:
362
+ List[Tuple[int, int]]: List of possible resolutions as tuples (height, width).
363
+
364
+ Examples:
365
+
366
+ >>> max_num_tiles = 4
367
+ >>> tile_size = 224
368
+ >>> find_supported_resolutions(max_num_tiles, tile_size)
369
+ [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)]
370
+ """
371
+
372
+ # create dictionary {aspect_ratio: [resolution1, ..., resolution n]}
373
+ # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]}
374
+ asp_dict = defaultdict(list)
375
+ for _tile_size in range(max_num_tiles, 0, -1):
376
+ factors = sorted(_get_factors(_tile_size))
377
+ asp_ratios = [(factor, _tile_size // factor) for factor in factors]
378
+ for height, width in asp_ratios:
379
+ ratio_float = height / width
380
+ asp_dict[ratio_float].append((height, width))
381
+
382
+ # get the resolutions multiplied by the tile_size
383
+ possible_resolutions = []
384
+ for ar, resolution in asp_dict.items():
385
+ for height, width in resolution:
386
+ possible_resolutions.append((height * tile_size, width * tile_size))
387
+
388
+ return possible_resolutions
389
+
390
+
391
+ # NOTE Copied from torchtune.data._utils.py
392
+ def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
393
+ """
394
+ Convenience method to load an image in torch.Tensor format from a local file path or remote source.
395
+
396
+ Args:
397
+ image_loc (Union[Path, str]): Local file path or remote source pointing to the image
398
+ which will be loaded in PIL format.
399
+
400
+ Note:
401
+ If loading an image from a remote source, the function expects the URL provided in ``image_loc``
402
+ to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".
403
+
404
+ Raises:
405
+ ValueError: If the image cannot be loaded from remote source, **or**
406
+ if the image cannot be opened as a :class:`~torch.Tensor`.
407
+
408
+ Examples:
409
+ >>> # Load from remote source
410
+ >>> image = load_image("https://www.wikipedia.org/en/bird.jpg")
411
+
412
+ >>> # Load from local file path
413
+ >>> image = load_image(Path("/home/user/bird.jpg"))
414
+
415
+ Returns:
416
+ torch.Tensor: The loaded image.
417
+ """
418
+
419
+ # If pointing to remote source, try to load to local
420
+ if isinstance(image_loc, str) and image_loc.startswith("http"):
421
+ try:
422
+ image_loc = request.urlopen(image_loc).read()
423
+ image = torchvision.io.decode_image(
424
+ torch.frombuffer(image_loc, dtype=torch.uint8),
425
+ mode="RGB",
426
+ )
427
+ except Exception as e:
428
+ raise ValueError("Failed to load remote image as torch.Tensor") from e
429
+
430
+ # Open the local image as a Tensor image
431
+ else:
432
+ try:
433
+ image = torchvision.io.decode_image(image_loc, mode="RGB")
434
+ except Exception as e:
435
+ raise ValueError("Failed to load local image as torch.Tensor") from e
436
+
437
+ return image
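The helpers in this file compose into a single image-preprocessing path. Below is a hedged sketch of that pipeline, assuming the module is importable at the path shown; the input image is random data standing in for a decoded RGB image.

```python
# Sketch: canvas selection -> aspect-preserving resize + pad -> tiling.
import torch
import torchvision
from torchtitan.experiments.multimodal.utils import (
    find_supported_resolutions,
    get_canvas_best_fit,
    resize_with_pad,
    tile_crop,
)

image = torch.rand(3, 300, 800)  # stand-in for a decoded RGB image
tile_size, max_num_tiles = 224, 4

possible = torch.tensor(find_supported_resolutions(max_num_tiles, tile_size))
canvas = get_canvas_best_fit(image, possible, resize_to_max_canvas=False)

resized = resize_with_pad(
    image,
    target_size=canvas,
    resample=torchvision.transforms.InterpolationMode.BILINEAR,
)
tiles = tile_crop(resized, tile_size)
print(canvas, tiles.shape)  # (224, 672) and [3, 3, 224, 224] for this input
```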
torchtitan/experiments/simple_fsdp/README.md ADDED
@@ -0,0 +1,40 @@
1
+ ## SimpleFSDP
2
+
3
+ This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework. It keeps the implementation simple for maintenance and composability, allows tracing of the full computation-communication graph, and brings performance gains via compiler backend optimizations.
4
+
5
+ ### Enable SimpleFSDP Training
6
+
7
+ ```bash
8
+ CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32
9
+ ```
10
+
11
+ Note: Mixed precision training support is still in progress. We set `training.mixed_precision_param` to `float32` for now and will remove this requirement once the integration is complete.
12
+
13
+ ### Composability Support
14
+
15
+ Some features require changes that are still landing in PyTorch; we are working on providing composability support for the following features:
16
+
17
+ | Feature | Support |
18
+ | :--------: | :--------: |
19
+ |Meta Initialization| ✅ |
20
+ |Activation Checkpointing| ✅ |
21
+ |Mixed Precision Training| 🚧 |
22
+ |Tensor Parallelism| 🚧 |
23
+ |Context Parallelism| ✅ |
24
+ |Pipeline Parallelism| ✅ |
25
+ |Distributed Checkpointing| 🚧 |
26
+ |Float8 Training| ❌ |
27
+
28
+
29
+ ### Citation
30
+
31
+ If you find SimpleFSDP useful, please consider citing the following paper:
32
+
33
+ ```latex
34
+ @article{zhang2024simplefsdp,
35
+ title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile},
36
+ author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
37
+ journal={arXiv preprint arXiv:2411.00284},
38
+ year={2024}
39
+ }
40
+ ```
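For readers skimming the README, a hedged sketch of what the SimpleFSDP frontend call looks like is given below. It mirrors the unit tests added in this folder (`data_parallel` from `simple_fsdp.py`); the mesh setup assumes an already-initialized process group with 8 ranks, and the names follow the tests rather than a stable public API.

```python
# Sketch: wrap a module with the SimpleFSDP frontend, then compile it.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

# assumes torch.distributed is initialized (e.g. via torchrun) with 8 ranks
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard_cp",))

model = torch.nn.Linear(8, 8).cuda()
model = data_parallel(model, device_mesh=mesh, mode="fully_shard")

# torch.compile can now trace the full computation + communication graph
model = torch.compile(model)
```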
torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.1 kB). View file
 
torchtitan/experiments/simple_fsdp/tests/test_numerics.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import copy
7
+
8
+ import torch
9
+ from torch.distributed._composable.fsdp import fully_shard
10
+
11
+ from torch.testing._internal.common_fsdp import FSDPTest
12
+
13
+ from torchtitan.components.loss import cross_entropy_loss
14
+ from torchtitan.distributed import ParallelDims
15
+ from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
16
+
17
+
18
+ class TestSimpleFSDP(FSDPTest):
19
+ def init_test(self):
20
+ self.optimizer = torch.optim.Adam
21
+ self.loss_fn = cross_entropy_loss
22
+ data_parallel_shard_degree = -1
23
+ if self.mode == "replicate":
24
+ self.dp_mesh_dim_names = ("dp_replicate",)
25
+ data_parallel_replicate_degree = self.world_size
26
+ elif self.mode == "fully_shard":
27
+ self.dp_mesh_dim_names = ("dp_shard_cp",)
28
+ data_parallel_replicate_degree = 1
29
+ elif self.mode == "hybrid_shard":
30
+ self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
31
+ data_parallel_replicate_degree = self.world_size // 2
32
+ else:
33
+ raise ValueError(f"Unsupported mode {self.mode}")
34
+
35
+ self.parallel_dims = ParallelDims(
36
+ dp_shard=data_parallel_shard_degree,
37
+ dp_replicate=data_parallel_replicate_degree,
38
+ cp=1,
39
+ tp=1,
40
+ pp=1,
41
+ world_size=self.world_size,
42
+ enable_loss_parallel=True,
43
+ )
44
+ self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")
45
+
46
+ def get_input(self):
47
+ inputs = torch.randn(8, 8).cuda()
48
+ labels = torch.randn(8, 8).cuda()
49
+ model = torch.nn.Linear(8, 8)
50
+ return model, inputs, labels
51
+
52
+ def run_fsdp2(self, model, inputs, labels, epoch=20):
53
+ fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
54
+ optim = self.optimizer(model.parameters(), lr=1e-4)
55
+ losses = []
56
+ for _ in range(epoch):
57
+ optim.zero_grad()
58
+ out = model(inputs)
59
+ loss = self.loss_fn(out, labels)
60
+ loss.backward()
61
+ optim.step()
62
+ losses.append(loss)
63
+ return losses
64
+
65
+ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
66
+ model = data_parallel(
67
+ model,
68
+ device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
69
+ mode=self.mode,
70
+ )
71
+ optim = self.optimizer(model.parameters(), lr=1e-4)
72
+ losses = []
73
+ for _ in range(epoch):
74
+ optim.zero_grad()
75
+ out = model(inputs)
76
+ loss = self.loss_fn(out, labels)
77
+ loss.backward()
78
+ optim.step()
79
+ losses.append(loss)
80
+ return losses
81
+
82
+ def test_replicate_convergence(self):
83
+ # unit test for replicate mode
84
+ self.mode = "replicate"
85
+ self.init_test()
86
+ model, inputs, labels = self.get_input()
87
+
88
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
89
+ simple_fsdp_replicate_losses = self.run_simple_fsdp(
90
+ copy.deepcopy(model), inputs, labels
91
+ )
92
+
93
+ for fsdp2_loss, simple_fsdp_replicate_loss in zip(
94
+ fsdp2_losses, simple_fsdp_replicate_losses
95
+ ):
96
+ assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss)
97
+
98
+ def test_fullyshard_convergence(self):
99
+ # unit test for fully_shard mode
100
+ self.mode = "fully_shard"
101
+ self.init_test()
102
+ model, inputs, labels = self.get_input()
103
+
104
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
105
+ simple_fsdp_fullyshard_losses = self.run_simple_fsdp(
106
+ copy.deepcopy(model), inputs, labels
107
+ )
108
+
109
+ for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
110
+ fsdp2_losses, simple_fsdp_fullyshard_losses
111
+ ):
112
+ assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss)
113
+
114
+ def test_hybridshard_convergence(self):
115
+ # unit test for hybrid_shard mode
116
+ self.mode = "hybrid_shard"
117
+ self.init_test()
118
+ model, inputs, labels = self.get_input()
119
+
120
+ fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
121
+ simple_fsdp_hybridshard_losses = self.run_simple_fsdp(
122
+ copy.deepcopy(model), inputs, labels
123
+ )
124
+
125
+ for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
126
+ fsdp2_losses, simple_fsdp_hybridshard_losses
127
+ ):
128
+ assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss)
torchtitan/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (186 Bytes). View file
 
torchtitan/models/llama3/__pycache__/model.cpython-312.pyc ADDED
Binary file (25.9 kB). View file
 
torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc ADDED
Binary file (15.1 kB). View file
 
torchtitan/models/llama3/parallelize_llama.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed._composable.replicate import replicate
15
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
16
+ checkpoint_wrapper as ptd_checkpoint_wrapper,
17
+ )
18
+
19
+ from torch.distributed.device_mesh import DeviceMesh
20
+ from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
21
+ from torch.distributed.tensor import Replicate, Shard
22
+ from torch.distributed.tensor.parallel import (
23
+ ColwiseParallel,
24
+ parallelize_module,
25
+ PrepareModuleInput,
26
+ RowwiseParallel,
27
+ SequenceParallel,
28
+ )
29
+
30
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
31
+ from torchtitan.distributed import ParallelDims
32
+ from torchtitan.tools.logging import logger
33
+
34
+
35
+ def parallelize_llama(
36
+ model: nn.Module,
37
+ world_mesh: DeviceMesh,
38
+ parallel_dims: ParallelDims,
39
+ job_config: JobConfig,
40
+ ):
41
+ """
42
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
43
+ parallelism to the model.
44
+
45
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
46
+ the model must fit on GPU or CPU memory.
47
+ """
48
+
49
+ if parallel_dims.tp_enabled:
50
+ if (
51
+ job_config.parallelism.enable_async_tensor_parallel
52
+ and not job_config.training.compile
53
+ ):
54
+ raise RuntimeError("Async TP requires --training.compile")
55
+
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ float8_is_rowwise = job_config.float8.recipe_name in (
58
+ "rowwise",
59
+ "rowwise_with_gw_hp",
60
+ )
61
+
62
+ # For now, float8 all-gather with TP is only supported for tensorwise
63
+ # float8 scaling recipes. For rowwise recipes, we use regular TP and
64
+ # all-gather happens in high precision.
65
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
66
+
67
+ apply_tp(
68
+ model,
69
+ world_mesh["tp"],
70
+ loss_parallel=parallel_dims.loss_parallel_enabled,
71
+ enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
72
+ enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
73
+ )
74
+
75
+ if job_config.model.use_flex_attn:
76
+ if job_config.activation_checkpoint.mode == "selective":
77
+ raise ValueError(
78
+ "FlexAttention is not compatible with selective AC yet. "
79
+ "See https://github.com/pytorch/pytorch/issues/147879"
80
+ )
81
+
82
+ if parallel_dims.cp_enabled:
83
+ raise ValueError(
84
+ "FlexAttention is not compatible with CP yet. "
85
+ "We are still working on this."
86
+ )
87
+
88
+ if job_config.activation_checkpoint.mode != "none":
89
+ apply_ac(model, job_config.activation_checkpoint)
90
+
91
+ # turn on per-TransformerBlock compile after AC wrapping and before FSDP
92
+ if job_config.training.compile:
93
+ apply_compile(model)
94
+
95
+ if (
96
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
97
+ ): # apply FSDP or HSDP, potentially with Context Parallel
98
+ if parallel_dims.dp_replicate_enabled:
99
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
100
+ else:
101
+ dp_mesh_dim_names = ("dp_shard_cp",)
102
+
103
+ apply_fsdp(
104
+ model,
105
+ world_mesh[tuple(dp_mesh_dim_names)],
106
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
107
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
108
+ pp_enabled=parallel_dims.pp_enabled,
109
+ cpu_offload=job_config.training.enable_cpu_offload,
110
+ reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
111
+ )
112
+
113
+ if parallel_dims.dp_replicate_enabled:
114
+ logger.info("Applied HSDP to the model")
115
+ else:
116
+ logger.info("Applied FSDP to the model")
117
+
118
+ if parallel_dims.cp_enabled:
119
+ logger.info("Applied Context Parallel to the model")
120
+
121
+ if job_config.training.enable_cpu_offload:
122
+ logger.info("Applied CPU Offloading to the model")
123
+ elif parallel_dims.dp_replicate_enabled:
124
+ if world_mesh.ndim > 1:
125
+ raise RuntimeError("DDP does not support > 1D parallelism")
126
+ apply_ddp(
127
+ model,
128
+ world_mesh,
129
+ enable_compile=job_config.training.compile,
130
+ enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
131
+ )
132
+
133
+ return model
134
+
135
+
136
+ def apply_tp(
137
+ model: nn.Module,
138
+ tp_mesh: DeviceMesh,
139
+ loss_parallel: bool,
140
+ enable_float8_tensorwise_tp: bool,
141
+ enable_async_tp: bool,
142
+ ):
143
+ """Apply tensor parallelism."""
144
+ # 1. Parallelize the embedding and shard its outputs (which are the first
145
+ # transformer block's inputs)
146
+ # 2. Parallelize the root norm layer over the sequence dim
147
+ # 3. Parallelize the final linear output layer
148
+ parallelize_module(
149
+ model,
150
+ tp_mesh,
151
+ {
152
+ "tok_embeddings": RowwiseParallel(
153
+ input_layouts=Replicate(),
154
+ output_layouts=Shard(1),
155
+ ),
156
+ "norm": SequenceParallel(),
157
+ "output": ColwiseParallel(
158
+ input_layouts=Shard(1),
159
+ output_layouts=Shard(-1) if loss_parallel else Replicate(),
160
+ use_local_output=not loss_parallel,
161
+ ),
162
+ },
163
+ )
164
+
165
+ # Parallel styles used for transformer block linear weights and their
166
+ # inputs may be different for float8 linears with tensorwise scaling.
167
+ if enable_float8_tensorwise_tp:
168
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
169
+ from torchao.float8.float8_tensor_parallel import (
170
+ Float8ColwiseParallel,
171
+ Float8RowwiseParallel,
172
+ PrepareFloat8ModuleInput,
173
+ )
174
+
175
+ rowwise_parallel, colwise_parallel, prepare_module_input = (
176
+ Float8RowwiseParallel,
177
+ Float8ColwiseParallel,
178
+ PrepareFloat8ModuleInput,
179
+ )
180
+ else:
181
+ rowwise_parallel, colwise_parallel, prepare_module_input = (
182
+ RowwiseParallel,
183
+ ColwiseParallel,
184
+ PrepareModuleInput,
185
+ )
186
+
187
+ # Apply tensor + sequence parallelism to every transformer block
188
+ # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
189
+ # by folding (and unfolding) the batch dimension and the sequence dimension.
190
+ # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
191
+ for layer_id, transformer_block in model.layers.items():
192
+ layer_plan = {
193
+ "attention_norm": SequenceParallel(),
194
+ "attention": prepare_module_input(
195
+ input_layouts=(Shard(1), None),
196
+ desired_input_layouts=(Replicate(), None),
197
+ ),
198
+ "attention.wq": colwise_parallel(),
199
+ "attention.wk": colwise_parallel(),
200
+ "attention.wv": colwise_parallel(),
201
+ "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
202
+ "ffn_norm": SequenceParallel(),
203
+ "feed_forward": prepare_module_input(
204
+ input_layouts=(Shard(1),),
205
+ desired_input_layouts=(Replicate(),),
206
+ ),
207
+ "feed_forward.w1": colwise_parallel(),
208
+ "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
209
+ "feed_forward.w3": colwise_parallel(),
210
+ }
211
+
212
+ parallelize_module(
213
+ module=transformer_block,
214
+ device_mesh=tp_mesh,
215
+ parallelize_plan=layer_plan,
216
+ )
217
+
218
+ if enable_async_tp:
219
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
220
+
221
+ torch._inductor.config._micro_pipeline_tp = True
222
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
223
+
224
+ logger.info(
225
+ f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
226
+ "Tensor Parallelism to the model"
227
+ )
228
+
229
+
230
+ # for selective op activation checkpointing
231
+ _save_list = {
232
+ torch.ops.aten.mm.default,
233
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
234
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
235
+ # for low precision training, it's useful to always save
236
+ # the result of max, since the absolute maximum is
237
+ # used to compute the scaling factor for quantization.
238
+ torch.ops.aten.max.default,
239
+ }
240
+
241
+
242
+ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
243
+ valid_ac_modes = ("full", "selective")
244
+ if ac_config.mode not in valid_ac_modes:
245
+ raise ValueError(
246
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
247
+ )
248
+
249
+ if ac_config.mode == "full":
250
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
251
+
252
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
253
+ use_op_sac = ac_config.selective_ac_option == "op"
254
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
255
+ if not use_op_sac and not use_layer_sac:
256
+ raise ValueError(
257
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
258
+ f"Valid options: 'op' or a positive int representing layer frequency"
259
+ )
260
+ if use_op_sac:
261
+ from torch.utils.checkpoint import (
262
+ CheckpointPolicy,
263
+ create_selective_checkpoint_contexts,
264
+ )
265
+
266
+ def _get_custom_policy(meta):
267
+ def _custom_policy(ctx, func, *args, **kwargs):
268
+ mode = "recompute" if ctx.is_recompute else "forward"
269
+ mm_count_key = f"{mode}_mm_count"
270
+ if func == torch.ops.aten.mm.default:
271
+ meta[mm_count_key] += 1
272
+ # Saves output of all compute ops, except every second mm
273
+ to_save = func in _save_list and not (
274
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
275
+ )
276
+ return (
277
+ CheckpointPolicy.MUST_SAVE
278
+ if to_save
279
+ else CheckpointPolicy.PREFER_RECOMPUTE
280
+ )
281
+
282
+ return _custom_policy
283
+
284
+ def selective_checkpointing_context_fn():
285
+ meta = defaultdict(int)
286
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
287
+
288
+ return ptd_checkpoint_wrapper(
289
+ module,
290
+ context_fn=selective_checkpointing_context_fn,
291
+ preserve_rng_state=False,
292
+ )
293
+ elif use_layer_sac:
294
+ # Checkpoint every `ac_freq` of the modules passed to this function
295
+ ac_freq = int(ac_config.selective_ac_option)
296
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
297
+ ptd_checkpoint_wrapper._count += 1
298
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
299
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
300
+ else:
301
+ return module
302
+
303
+
304
+ def apply_ac(model: nn.Module, ac_config):
305
+ """Apply activation checkpointing to the model."""
306
+ for layer_id, transformer_block in model.layers.named_children():
307
+ transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
308
+ model.layers.register_module(layer_id, transformer_block)
309
+
310
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
311
+
312
+
313
+ def apply_compile(model: nn.Module):
314
+ """
315
+ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
316
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
317
+ """
318
+ for layer_id, transformer_block in model.layers.named_children():
319
+ transformer_block = torch.compile(transformer_block, fullgraph=True)
320
+ model.layers.register_module(layer_id, transformer_block)
321
+
322
+ logger.info("Compiling each TransformerBlock with torch.compile")
323
+
324
+
325
+ def apply_fsdp(
326
+ model: nn.Module,
327
+ dp_mesh: DeviceMesh,
328
+ param_dtype: torch.dtype,
329
+ reduce_dtype: torch.dtype,
330
+ pp_enabled: bool,
331
+ cpu_offload: bool = False,
332
+ reshard_after_forward_policy: str = "default",
333
+ ):
334
+ """
335
+ Apply data parallelism (via FSDP2) to the model.
336
+
337
+ Args:
338
+ model (nn.Module): The model to apply data parallelism to.
339
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
340
+ param_dtype (torch.dtype): The data type to use for model parameters.
341
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
342
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
343
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
344
+ reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
345
+ Other options: "never", "always".
346
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
347
+ - "always" will enable `reshard_after_forward` for all forward passes.
348
+ - "never" will disable `reshard_after_forward` for all forward passes.
349
+
350
+ """
351
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
352
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
353
+ if cpu_offload:
354
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
355
+
356
+ for layer_id, transformer_block in model.layers.items():
357
+ if reshard_after_forward_policy == "always":
358
+ reshard_after_forward = True
359
+ elif reshard_after_forward_policy == "never":
360
+ reshard_after_forward = False
361
+ elif reshard_after_forward_policy == "default":
362
+ if pp_enabled:
363
+ # For PP, do not reshard after forward to avoid per-microbatch
364
+ # all-gathers, which can be expensive and non-overlapped
365
+ reshard_after_forward = False
366
+ else:
367
+ # As an optimization, do not reshard after forward for the last
368
+ # transformer block since FSDP would prefetch it immediately
369
+ reshard_after_forward = int(layer_id) < len(model.layers) - 1
370
+ else:
371
+ raise ValueError(
372
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
373
+ )
374
+ fully_shard(
375
+ transformer_block,
376
+ **fsdp_config,
377
+ reshard_after_forward=reshard_after_forward,
378
+ )
379
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
380
+
381
+
382
+ def apply_ddp(
383
+ model: nn.Module,
384
+ dp_mesh: DeviceMesh,
385
+ enable_compile: bool,
386
+ enable_compiled_autograd: bool,
387
+ ):
388
+ if enable_compile:
389
+ if enable_compiled_autograd:
390
+ torch._dynamo.config.optimize_ddp = (
391
+ "python_reducer_without_compiled_forward"
392
+ )
393
+ else:
394
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
395
+
396
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
397
+
398
+ logger.info("Applied DDP to the model")
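The op-level selective activation checkpointing policy in `_apply_ac_to_transformer_block` can be hard to picture in isolation. Below is a hedged, standalone sketch of the same idea applied to a plain function via `torch.utils.checkpoint`; the toy `block` function and shapes are made up, and only the policy mechanics mirror the code above.

```python
# Sketch: save every other matmul's output, recompute everything else.
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def context_fn():
    meta = defaultdict(int)

    def policy(ctx, func, *args, **kwargs):
        # count matmuls separately for the forward and recompute passes
        key = "recompute_mm" if ctx.is_recompute else "forward_mm"
        if func == torch.ops.aten.mm.default:
            meta[key] += 1
            if meta[key] % 2 == 1:  # save the 1st, 3rd, ... matmul outputs
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return create_selective_checkpoint_contexts(policy)


def block(x, w1, w2):
    return torch.mm(torch.relu(torch.mm(x, w1)), w2)


x = torch.randn(16, 64, requires_grad=True)
w1, w2 = torch.randn(64, 64), torch.randn(64, 64)
out = checkpoint(block, x, w1, w2, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```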
torchtitan/models/llama3/train_configs/llama3_405b.toml ADDED
@@ -0,0 +1,63 @@
1
+ # torchtitan Config.toml
2
+ # NOTE: this toml config is a preset for 128 H100 GPUs.
3
+
4
+ [job]
5
+ dump_folder = "./outputs"
6
+ description = "Llama 3 405B training"
7
+
8
+ [profiling]
9
+ enable_profiling = true
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 100
12
+
13
+ [metrics]
14
+ log_freq = 10
15
+ enable_tensorboard = true
16
+ save_tb_folder = "tb"
17
+
18
+ [model]
19
+ name = "llama3"
20
+ flavor = "405B"
21
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
22
+ tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
23
+ converters = "float8"
24
+
25
+ [optimizer]
26
+ name = "AdamW"
27
+ lr = 8e-5
28
+ eps = 1e-8
29
+
30
+ [lr_scheduler]
31
+ warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
32
+
33
+ [training]
34
+ batch_size = 2
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = true
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8 # 8-way TP
45
+ enable_async_tensor_parallel = true
46
+ pipeline_parallel_degree = 1
47
+ context_parallel_degree = 1
48
+
49
+ [checkpoint]
50
+ enable_checkpoint = false
51
+ folder = "checkpoint"
52
+ interval = 500
53
+ model_weights_only = false
54
+ export_dtype = "float32"
55
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
56
+
57
+ [activation_checkpoint]
58
+ mode = 'full' # ['none', 'selective', 'full']
59
+
60
+ [float8]
61
+ enable_fsdp_float8_all_gather = true
62
+ precompute_float8_dynamic_scale_for_fsdp = true
63
+ filter_fqns = "output"
torchtitan/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.28 kB). View file