""" PyTorch Mistral model. |
|
""" |
|
from typing import Dict, Optional, Union |
|
import inspect |
|
|
|
import torch |
|
from flash_attn import bert_padding |
|
from flash_attn.flash_attn_interface import ( |
|
flash_attn_varlen_func, |
|
flash_attn_with_kvcache, |
|
) |
|
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding |
|
from nanotron import distributed as dist |
|
from nanotron import logging |
|
from nanotron.config import ParallelismArgs, RecomputeGranularity |
|
from nanotron.generation.generate_store import AttachableStore |
|
from nanotron.logging import log_rank |
|
from nanotron.models import NanotronModel |
|
from nanotron.nn.layer_norm import TritonRMSNorm |
|
from nanotron.parallel import ParallelContext |
|
from nanotron.parallel.parameters import NanotronParameter |
|
from nanotron.parallel.pipeline_parallel.block import ( |
|
PipelineBlock, |
|
TensorPointer, |
|
) |
|
from nanotron.parallel.pipeline_parallel.p2p import P2P |
|
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy |
|
from nanotron.parallel.tensor_parallel.nn import ( |
|
TensorParallelColumnLinear, |
|
TensorParallelEmbedding, |
|
TensorParallelLinearMode, |
|
TensorParallelRowLinear, |
|
) |
|
from nanotron.random import RandomStates |
|
from nanotron.utils import checkpoint_method |
|
from nanotron.nn.activations import ACT2FN |
|
from torch import nn |
|
|
|
from config_mistral_7b import MistralConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters) |
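# flash-attn >= 2.3 is required for the `window_size` argument used by Mistral's sliding-window attention.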
|
|
|
|
|
class RotaryEmbedding(nn.Module): |
|
def __init__(self, dim: int, end: int, theta: float = 10000.0): |
|
super().__init__() |
|
assert dim % 2 == 0 |
|
self.dim = dim |
|
self.end = end |
|
self.theta = theta |
|
|
|
|
|
self.freqs_cis: torch.Tensor |
|
self._initialized_buffer = False |
|
|
|
def init_rotary_embeddings(self): |
|
if self._initialized_buffer is True: |
|
|
|
return |
|
self.register_buffer( |
|
"freqs_cis", |
|
torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"), |
|
persistent=False, |
|
) |
|
assert self.freqs_cis.device.type == "cuda" |
|
|
|
if self.freqs_cis.dtype != torch.float: |
|
self.freqs_cis = self.freqs_cis.to(torch.float) |
|
assert self.freqs_cis.dtype == torch.float |
|
freqs = 1.0 / ( |
|
self.theta |
|
** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim) |
|
) |
|
t = torch.arange(self.end, device="cuda") |
|
freqs = torch.outer(t, freqs).float() |
|
complex_freqs = torch.polar(torch.ones_like(freqs), freqs) |
|
freqs = torch.view_as_real(complex_freqs) |
|
self.freqs_cis.copy_(freqs) |
|
self._initialized_buffer = True |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
position_ids: Optional[torch.LongTensor], |
|
): |
|
batch_size, seq_length, num_heads, inner_dim = x.shape |
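        # Grow the rotary table (doubling `end`) until it covers all requested positions, then rebuild freqs_cis.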
|
while ( |
|
position_ids is not None and position_ids[-1, -1] >= self.end |
|
) or seq_length >= self.end: |
|
self.end *= 2 |
|
self._initialized_buffer = False |
|
if self._initialized_buffer is False: |
|
print(f"Initializing rotary embeddings with end={self.end}") |
|
self.init_rotary_embeddings() |
|
dtype = x.dtype |
|
assert inner_dim % 2 == 0 |
|
x = x.view( |
|
batch_size, seq_length, num_heads, inner_dim // 2, 2 |
|
) |
|
if x.dtype == torch.bfloat16: |
|
x = x.float() |
|
complex_x = torch.view_as_complex(x) |
|
if position_ids is None: |
|
freqs_cis = self.freqs_cis[None, :seq_length, None, :] |
|
else: |
|
|
|
if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: |
|
raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}") |
|
freqs_cis = self.freqs_cis[position_ids][:, :, None, :] |
|
complex_freqs = torch.view_as_complex(freqs_cis) |
|
x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim) |
|
return x_out.type(dtype) |
|
|
|
|
|
class GLUActivation(nn.Module): |
|
def __init__(self, act_fn_name: str): |
|
super().__init__() |
|
self.act = ACT2FN[act_fn_name] |
|
|
|
def forward(self, merged_states: torch.Tensor): |
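        # GLU gating: split the fused projection into gate and up halves and combine them as act(gate) * up.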
|
gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) |
|
return self.act(gate_states) * up_states |
|
|
|
|
|
class MLP(nn.Module): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
parallel_config: Optional[ParallelismArgs], |
|
tp_pg: dist.ProcessGroup, |
|
): |
|
super().__init__() |
|
|
|
|
|
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE |
|
tp_linear_async_communication = ( |
|
parallel_config.tp_linear_async_communication if parallel_config is not None else False |
|
) |
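        # The gate and up projections are fused into one column-parallel matmul; contiguous_chunks keeps each
        # projection contiguous within the sharded weight.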
|
|
|
gate_up_contiguous_chunks = ( |
|
config.intermediate_size, |
|
config.intermediate_size, |
|
) |
|
self.gate_up_proj = TensorParallelColumnLinear( |
|
config.hidden_size, |
|
2 * config.intermediate_size, |
|
pg=tp_pg, |
|
mode=tp_mode, |
|
bias=False, |
|
async_communication=tp_linear_async_communication, |
|
contiguous_chunks=gate_up_contiguous_chunks, |
|
) |
|
|
|
self.down_proj = TensorParallelRowLinear( |
|
config.intermediate_size, |
|
config.hidden_size, |
|
pg=tp_pg, |
|
mode=tp_mode, |
|
bias=False, |
|
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, |
|
) |
|
|
|
self.split_silu_mul = GLUActivation(config.hidden_act) |
|
|
|
def forward(self, hidden_states): |
|
merged_states = self.gate_up_proj(hidden_states) |
|
hidden_states = self.down_proj(self.split_silu_mul(merged_states)) |
|
return {"hidden_states": hidden_states} |
|
|
|
|
|
class CoreAttention(nn.Module): |
|
def __init__(self, config: MistralConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int): |
|
super().__init__() |
|
|
|
assert ( |
|
config.hidden_size % config.num_attention_heads == 0 |
|
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}." |
|
self.d_qk = config.hidden_size // config.num_attention_heads |
|
self.d_v = config.hidden_size // config.num_attention_heads |
|
self.dropout = config.attn_pdrop |
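        # Attention activation checkpointing is disabled by default; when enabled, @checkpoint_method recomputes this forward during backward.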
|
|
|
self.checkpoint_attention = False |
|
|
|
        if config.sliding_window_size is not None:
            assert (
                _flash_supports_window_size
            ), "Current version of flash-attn doesn't support sliding window: `pip install flash-attn>=2.3`"
        self.sliding_window_size = config.sliding_window_size
|
|
|
@checkpoint_method(attr_name="checkpoint_attention") |
|
def forward( |
|
self, |
|
query_states: torch.Tensor, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
q_sequence_mask: torch.Tensor, |
|
kv_sequence_mask: torch.Tensor, |
|
): |
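        # flash_attn_varlen_func expects packed sequences plus cumulative sequence lengths:
        # cu_seqlens[i] is the start offset of sequence i in the flattened (total_tokens, heads, head_dim) layout.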
|
|
|
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) |
|
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) |
|
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) |
|
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) |
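        # A single-token query (decode step) needs no causal mask; longer queries attend causally.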
|
|
|
|
|
|
|
causal = False if q_sequence_mask.shape[1] == 1 else True |
|
attn_output = flash_attn_varlen_func( |
|
q=query_states, |
|
k=key_states, |
|
v=value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=q_sequence_mask.shape[1], |
|
max_seqlen_k=kv_sequence_mask.shape[1], |
|
dropout_p=self.dropout if self.training else 0.0, |
|
softmax_scale=None, |
|
causal=causal, |
|
window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1), |
|
return_attn_probs=False, |
|
) |
|
|
|
return attn_output |
|
|
|
|
|
def pad_to_right(tensor, mask, new_tensor=None): |
|
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states) |
|
Args: |
|
tensor: (batch_size, seqlen, d1, d2) |
|
mask: (batch_size, seqlen) |
|
new_tensor: (batch_size, new_tensor_seqlen, d1, d2) |
|
Returns: |
|
new_tensor: (batch_size, new_tensor_seqlen, d1, d2) |
|
right_padded_mask: (batch_size, seqlen) |
|
""" |
|
|
|
unpad_seqlens = mask.sum(1) |
|
|
|
max_seqlen = mask.shape[1] |
|
|
|
|
|
indices = torch.arange(max_seqlen, device=mask.device) |
|
|
|
right_padded_mask = indices < unpad_seqlens[:, None] |
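    # Gather the valid (unpadded) values and scatter them into the left-aligned slots of the right-padded layout.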
|
|
|
useful_values = tensor[mask] |
|
|
|
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor |
|
|
|
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values |
|
return new_tensor, right_padded_mask |
|
|
|
|
|
class CausalSelfAttention(nn.Module, AttachableStore): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
parallel_config: Optional[ParallelismArgs], |
|
tp_pg: dist.ProcessGroup, |
|
layer_idx: int, |
|
): |
|
super().__init__() |
|
|
|
assert ( |
|
config.num_attention_heads % tp_pg.size() == 0 |
|
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})." |
|
try: |
|
assert ( |
|
config.num_key_value_heads % tp_pg.size() == 0 |
|
), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})." |
|
except AttributeError: |
|
log_rank( |
|
"WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads", |
|
logger=logger, |
|
level=logging.WARNING, |
|
rank=0, |
|
) |
|
|
|
config.num_key_value_heads = config.num_attention_heads |
|
assert ( |
|
config.num_attention_heads % config.num_key_value_heads == 0 |
|
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})." |
|
self.n_local_q_heads = config.num_attention_heads // tp_pg.size() |
|
self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size() |
|
self.n_repeats = config.num_attention_heads // config.num_key_value_heads |
|
self.is_gqa = config.num_attention_heads != config.num_key_value_heads |
|
self.d_qk = config.hidden_size // config.num_attention_heads |
|
self.d_v = config.hidden_size // config.num_attention_heads |
|
self.d_model = config.hidden_size |
|
|
|
|
|
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE |
|
tp_linear_async_communication = ( |
|
parallel_config.tp_linear_async_communication if parallel_config is not None else False |
|
) |
|
|
|
|
|
|
|
qkv_contiguous_chunks = ( |
|
config.num_attention_heads * self.d_qk, |
|
config.num_key_value_heads * self.d_qk, |
|
config.num_key_value_heads * self.d_qk, |
|
) |
|
self.qkv_proj = TensorParallelColumnLinear( |
|
self.d_model, |
|
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, |
|
pg=tp_pg, |
|
mode=tp_mode, |
|
bias=False, |
|
async_communication=tp_linear_async_communication, |
|
contiguous_chunks=qkv_contiguous_chunks, |
|
) |
|
|
|
self.rotary_embedding = RotaryEmbedding( |
|
dim=self.d_qk, |
|
end=config.max_position_embeddings, |
|
theta=config.rope_theta |
|
) |
|
|
|
|
|
self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True) |
|
|
|
self.o_proj = TensorParallelRowLinear( |
|
config.num_attention_heads * self.d_qk, |
|
self.d_model, |
|
pg=tp_pg, |
|
mode=tp_mode, |
|
bias=False, |
|
async_communication=tp_linear_async_communication, |
|
) |
|
|
|
self.attention = CoreAttention( |
|
config, |
|
parallel_config=parallel_config, |
|
layer_idx=layer_idx, |
|
) |
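        # The KV cache built at prefill time holds one slot per position, up to max_position_embeddings.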
|
|
|
self.prefill_kv_len = ( |
|
config.max_position_embeddings |
|
) |
|
|
|
def forward( |
|
self, |
|
hidden_states, |
|
sequence_mask, |
|
): |
|
qkv_states = self.qkv_proj( |
|
hidden_states |
|
) |
|
q_length, batch_size, _ = qkv_states.shape |
|
|
|
if self.is_gqa: |
|
query_states, key_states, value_states = torch.split( |
|
qkv_states, |
|
[ |
|
self.n_local_q_heads * self.d_qk, |
|
self.n_local_kv_heads * self.d_qk, |
|
self.n_local_kv_heads * self.d_qk, |
|
], |
|
dim=-1, |
|
) |
|
|
|
query_states = ( |
|
query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk) |
|
) |
|
key_states = ( |
|
key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) |
|
) |
|
value_states = ( |
|
value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) |
|
) |
|
else: |
|
query_states, key_states, value_states = ( |
|
qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk) |
|
.permute(2, 1, 0, 3, 4) |
|
.contiguous() |
|
) |
|
|
|
store = self.get_local_store() |
|
if store is not None: |
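            # Inference/generation path: key and value states are cached in the attached store across decoding steps.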
|
|
|
assert key_states.requires_grad is False |
|
assert value_states.requires_grad is False |
|
print("Using store") |
|
if "position_offsets" in store: |
|
old_position_offsets = store["position_offsets"] |
|
position_ids = old_position_offsets[:, None] + sequence_mask |
|
else: |
|
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1 |
|
position_offsets = position_ids[:, -1] |
|
|
|
|
|
|
|
old_rotary_embed_end = self.rotary_embedding.end |
|
query_states = self.rotary_embedding(query_states, position_ids=position_ids) |
|
key_states = self.rotary_embedding(key_states, position_ids=position_ids) |
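            # Prefill (no cache yet): run varlen flash attention over the whole prompt and populate the KV cache.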
|
|
|
if "key" not in store: |
|
|
|
|
|
|
|
|
|
assert ~( |
|
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) |
|
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing" |
|
|
|
|
|
k_cache = torch.zeros( |
|
( |
|
batch_size, |
|
self.prefill_kv_len, |
|
self.n_local_kv_heads, |
|
self.d_qk, |
|
), |
|
dtype=query_states.dtype, |
|
device=query_states.device, |
|
) |
|
v_cache = torch.zeros( |
|
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v), |
|
dtype=query_states.dtype, |
|
device=query_states.device, |
|
) |
|
|
|
|
|
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( |
|
query_states, |
|
sequence_mask, |
|
) |
|
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( |
|
key_states, sequence_mask |
|
) |
|
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) |
|
|
|
output_unpad = flash_attn_varlen_func( |
|
q=query_unpad, |
|
k=key_unpad, |
|
v=value_unpad, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_k=max_seqlen_k, |
|
dropout_p=0.0, |
|
softmax_scale=None, |
|
causal=True, |
|
return_attn_probs=False, |
|
) |
|
|
|
attention_output = bert_padding.pad_input( |
|
output_unpad, indices_q, batch_size, q_length |
|
) |
|
|
|
pad_to_right(key_states, sequence_mask, new_tensor=k_cache) |
|
pad_to_right(value_states, sequence_mask, new_tensor=v_cache) |
|
|
|
else: |
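                # Decode: reuse the cached keys/values and let flash_attn_with_kvcache append the new tokens in place.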
|
|
|
|
|
k_cache = store["key"] |
|
v_cache = store["value"] |
|
|
|
|
|
|
|
if self.rotary_embedding.end > old_rotary_embed_end: |
|
k_cache = torch.cat( |
|
[ |
|
k_cache, |
|
torch.zeros( |
|
( |
|
batch_size, |
|
self.rotary_embedding.end - old_rotary_embed_end, |
|
self.n_local_kv_heads, |
|
self.d_qk, |
|
), |
|
dtype=query_states.dtype, |
|
device=query_states.device, |
|
), |
|
], |
|
dim=1, |
|
) |
|
|
|
v_cache = torch.cat( |
|
[ |
|
v_cache, |
|
torch.zeros( |
|
( |
|
batch_size, |
|
self.rotary_embedding.end - old_rotary_embed_end, |
|
self.n_local_kv_heads, |
|
self.d_v, |
|
), |
|
dtype=query_states.dtype, |
|
device=query_states.device, |
|
), |
|
], |
|
dim=1, |
|
) |
|
|
|
assert ( |
|
k_cache.shape[1] == self.rotary_embedding.end |
|
), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" |
|
assert ( |
|
v_cache.shape[1] == self.rotary_embedding.end |
|
), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" |
|
|
|
|
|
query_states = query_states.view( |
|
batch_size, q_length, self.n_local_q_heads, self.d_qk |
|
) |
|
kv_length = key_states.shape[1] |
|
key_states = key_states.view( |
|
batch_size, kv_length, self.n_local_kv_heads, self.d_qk |
|
) |
|
value_states = value_states.view( |
|
batch_size, kv_length, self.n_local_kv_heads, self.d_v |
|
) |
|
|
|
attention_output = flash_attn_with_kvcache( |
|
query_states, |
|
k_cache, |
|
v_cache, |
|
key_states, |
|
value_states, |
|
rotary_cos=None, |
|
rotary_sin=None, |
|
|
|
cache_seqlens=position_offsets.contiguous(), |
|
softmax_scale=None, |
|
causal=True, |
|
rotary_interleaved=False, |
|
) |
|
|
|
store.update( |
|
{ |
|
"key": k_cache, |
|
"value": v_cache, |
|
"position_offsets": position_offsets, |
|
} |
|
) |
|
|
|
else: |
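            # Training path: apply rotary embeddings with the flash-attn kernel and run packed varlen attention.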
|
|
|
|
|
|
|
|
|
key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) |
|
|
|
key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() |
|
query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) |
|
|
|
key_states, value_states = torch.split(key_value_states, 1, dim=2) |
|
|
|
q_sequence_mask = sequence_mask |
|
kv_sequence_mask = sequence_mask |
|
|
|
kv_length = key_states.shape[1] |
|
|
|
|
|
query_states = query_states.view( |
|
batch_size * q_length, self.n_local_q_heads, self.d_qk |
|
) |
|
|
|
key_states = key_states.view( |
|
batch_size * kv_length, self.n_local_kv_heads, self.d_qk |
|
) |
|
value_states = value_states.view( |
|
batch_size * kv_length, self.n_local_kv_heads, self.d_v |
|
) |
|
|
|
attention_output = self.attention( |
|
query_states=query_states, |
|
key_states=key_states, |
|
value_states=value_states, |
|
q_sequence_mask=q_sequence_mask, |
|
kv_sequence_mask=kv_sequence_mask, |
|
) |
|
|
|
attention_output = ( |
|
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) |
|
) |
|
output = self.o_proj(attention_output) |
|
|
|
return {"hidden_states": output, "sequence_mask": sequence_mask} |
|
|
|
|
|
class MistralDecoderLayer(nn.Module): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
parallel_config: Optional[ParallelismArgs], |
|
tp_pg: dist.ProcessGroup, |
|
layer_idx: int, |
|
): |
|
super().__init__() |
|
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.attn = CausalSelfAttention( |
|
config=config, |
|
parallel_config=parallel_config, |
|
tp_pg=tp_pg, |
|
layer_idx=layer_idx, |
|
) |
|
|
|
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) |
|
|
|
def forward( |
|
self, |
|
hidden_states: Union[torch.Tensor, TensorPointer], |
|
sequence_mask: Union[torch.Tensor, TensorPointer], |
|
) -> Dict[str, Union[torch.Tensor, TensorPointer]]: |
|
residual = hidden_states |
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) |
|
hidden_states = output["hidden_states"] |
|
hidden_states = hidden_states + residual |
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] |
|
hidden_states = hidden_states + residual |
|
|
|
return { |
|
"hidden_states": hidden_states, |
|
"sequence_mask": output["sequence_mask"], |
|
} |
|
|
|
|
|
class Embedding(nn.Module, AttachableStore): |
|
def __init__(self, tp_pg: dist.ProcessGroup, config: MistralConfig, parallel_config: Optional[ParallelismArgs]): |
|
super().__init__() |
|
self.token_embedding = TensorParallelEmbedding( |
|
num_embeddings=config.vocab_size, |
|
embedding_dim=config.hidden_size, |
|
padding_idx=config.pad_token_id, |
|
pg=tp_pg, |
|
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, |
|
) |
|
self.pg = tp_pg |
|
|
|
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): |
|
store = self.get_local_store() |
|
if store is not None: |
|
if "past_length" in store: |
|
past_length = store["past_length"] |
|
else: |
|
past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) |
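            # Track how many non-padded tokens each sequence has processed so far across decoding steps.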
|
|
|
cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) |
|
|
|
store["past_length"] = past_length + cumsum_mask[:, -1] |
|
|
|
|
|
input_ids = input_ids.transpose(0, 1) |
|
input_embeds = self.token_embedding(input_ids) |
|
return {"input_embeds": input_embeds} |
|
|
|
|
|
class MistralModel(nn.Module): |
|
"""Build pipeline graph""" |
|
|
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
parallel_context: ParallelContext, |
|
parallel_config: Optional[ParallelismArgs], |
|
): |
|
super().__init__() |
|
|
|
|
|
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) |
|
self.config = config |
|
self.parallel_config = parallel_config |
|
self.parallel_context = parallel_context |
|
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE |
|
tp_linear_async_communication = ( |
|
parallel_config.tp_linear_async_communication if parallel_config is not None else False |
|
) |
|
|
|
self.token_position_embeddings = PipelineBlock( |
|
p2p=self.p2p, |
|
module_builder=Embedding, |
|
module_kwargs={ |
|
"tp_pg": parallel_context.tp_pg, |
|
"config": config, |
|
"parallel_config": parallel_config, |
|
}, |
|
module_input_keys={"input_ids", "input_mask"}, |
|
module_output_keys={"input_embeds"}, |
|
) |
|
|
|
self.decoder = nn.ModuleList( |
|
[ |
|
PipelineBlock( |
|
p2p=self.p2p, |
|
module_builder=MistralDecoderLayer, |
|
module_kwargs={ |
|
"config": config, |
|
"parallel_config": parallel_config, |
|
"tp_pg": parallel_context.tp_pg, |
|
"layer_idx": layer_idx, |
|
}, |
|
module_input_keys={"hidden_states", "sequence_mask"}, |
|
module_output_keys={"hidden_states", "sequence_mask"}, |
|
) |
|
for layer_idx in range(config.num_hidden_layers) |
|
] |
|
) |
|
|
|
self.final_layer_norm = PipelineBlock( |
|
p2p=self.p2p, |
|
module_builder=TritonRMSNorm, |
|
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, |
|
module_input_keys={"input"}, |
|
module_output_keys={"hidden_states"}, |
|
) |
|
|
|
self.lm_head = PipelineBlock( |
|
p2p=self.p2p, |
|
|
|
module_builder=TensorParallelColumnLinear, |
|
module_kwargs={ |
|
"in_features": config.hidden_size, |
|
"out_features": config.vocab_size, |
|
"pg": parallel_context.tp_pg, |
|
"bias": False, |
|
|
|
"mode": self.tp_mode, |
|
"async_communication": tp_linear_async_communication, |
|
}, |
|
module_input_keys={"x"}, |
|
module_output_keys={"logits"}, |
|
) |
|
|
|
self.cast_to_fp32 = PipelineBlock( |
|
p2p=self.p2p, |
|
module_builder=lambda: lambda x: x.float(), |
|
module_kwargs={}, |
|
module_input_keys={"x"}, |
|
module_output_keys={"output"}, |
|
) |
|
|
|
def forward( |
|
self, |
|
input_ids: Union[torch.Tensor, TensorPointer], |
|
input_mask: Union[torch.Tensor, TensorPointer], |
|
): |
|
return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] |
|
|
|
def forward_with_hidden_states( |
|
self, |
|
input_ids: Union[torch.Tensor, TensorPointer], |
|
input_mask: Union[torch.Tensor, TensorPointer], |
|
): |
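        # On ranks that don't own a given pipeline block, the corresponding inputs/outputs are TensorPointers rather than real tensors.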
|
|
|
|
|
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) |
|
|
|
hidden_encoder_states = { |
|
"hidden_states": output["input_embeds"], |
|
"sequence_mask": input_mask, |
|
} |
|
for encoder_block in self.decoder: |
|
hidden_encoder_states = encoder_block(**hidden_encoder_states) |
|
|
|
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] |
|
|
|
sharded_logits = self.lm_head(x=hidden_states)["logits"] |
|
|
|
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] |
|
|
|
return fp32_sharded_logits, hidden_states |
|
|
|
def get_block_compute_costs(self): |
|
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" |
|
model_config = self.config |
|
d_ff = model_config.intermediate_size |
|
d_qkv = model_config.hidden_size // model_config.num_attention_heads |
|
block_compute_costs = { |
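            # Per-layer weight matmuls: 4 attention projections (q, k, v, o) plus 3 GLU MLP matmuls (gate, up, down).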
|
|
|
MistralDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size |
|
+ 3 * d_ff * model_config.hidden_size, |
|
|
|
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, |
|
} |
|
return block_compute_costs |
|
|
|
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): |
|
"""Get flops per second for a given model""" |
|
world_size = self.parallel_context.world_pg.size() |
|
try: |
|
num_key_values_heads = self.config.num_key_value_heads |
|
except AttributeError: |
|
num_key_values_heads = self.config.num_attention_heads |
|
|
|
model_flops, hardware_flops = get_flops( |
|
num_layers=self.config.num_hidden_layers, |
|
hidden_size=self.config.hidden_size, |
|
num_heads=self.config.num_attention_heads, |
|
num_key_value_heads=num_key_values_heads, |
|
vocab_size=self.config.vocab_size, |
|
ffn_hidden_size=self.config.intermediate_size, |
|
seq_len=sequence_length, |
|
batch_size=global_batch_size, |
|
recompute_granularity=self.parallel_config.recompute_granularity, |
|
) |
|
|
|
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) |
|
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) |
|
return model_flops_per_s, hardware_flops_per_s |
|
|
|
|
|
@torch.jit.script |
|
def masked_mean(loss, label_mask, dtype): |
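    # type: (Tensor, Tensor, torch.dtype) -> Tensor
    # Average the loss over non-masked tokens only.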
|
|
|
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() |
|
|
|
|
|
class Loss(nn.Module): |
|
def __init__(self, tp_pg: dist.ProcessGroup): |
|
super().__init__() |
|
self.tp_pg = tp_pg |
|
|
|
def forward( |
|
self, |
|
sharded_logits: torch.Tensor, |
|
label_ids: torch.Tensor, |
|
label_mask: torch.Tensor, |
|
) -> Dict[str, torch.Tensor]: |
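        # Cross-entropy is computed over the vocabulary shards without gathering the full logits; upcasting to fp32 improves numerical stability.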
|
|
|
|
|
loss = sharded_cross_entropy( |
|
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float |
|
).transpose(0, 1) |
|
|
|
loss = masked_mean(loss, label_mask, dtype=torch.float) |
|
|
|
|
|
return {"loss": loss} |
|
|
|
|
|
class MistralForTraining(NanotronModel): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
parallel_context: ParallelContext, |
|
parallel_config: Optional[ParallelismArgs], |
|
random_states: Optional[RandomStates] = None, |
|
): |
|
super().__init__() |
|
|
|
|
self.model = MistralModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) |
|
self.loss = PipelineBlock( |
|
p2p=self.model.p2p, |
|
module_builder=Loss, |
|
module_kwargs={"tp_pg": parallel_context.tp_pg}, |
|
module_input_keys={ |
|
"sharded_logits", |
|
"label_ids", |
|
"label_mask", |
|
}, |
|
module_output_keys={"loss"}, |
|
) |
|
self.parallel_context = parallel_context |
|
self.config = config |
|
self.parallel_config = parallel_config |
|
|
|
def forward( |
|
self, |
|
input_ids: Union[torch.Tensor, TensorPointer], |
|
input_mask: Union[torch.Tensor, TensorPointer], |
|
label_ids: Union[torch.Tensor, TensorPointer], |
|
label_mask: Union[torch.Tensor, TensorPointer], |
|
) -> Dict[str, Union[torch.Tensor, TensorPointer]]: |
|
sharded_logits = self.model( |
|
input_ids=input_ids, |
|
input_mask=input_mask, |
|
) |
|
loss = self.loss( |
|
sharded_logits=sharded_logits, |
|
label_ids=label_ids, |
|
label_mask=label_mask, |
|
)["loss"] |
|
return {"loss": loss} |
|
|
|
@torch.no_grad() |
|
def init_model_randomly(self, init_method, scaled_init_method): |
|
"""Initialize model parameters randomly. |
|
Args: |
|
            init_method (callable): used for embedding weights, the QKV projection in attention, the first MLP layer and the lm_head
            scaled_init_method (callable): used for the attention output projection and the second MLP layer

        Note:
            LayerNorm weights are initialized to all 0s or all 1s depending on `apply_layernorm_1p`
|
""" |
|
model = self |
|
initialized_parameters = set() |
|
|
|
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} |
|
|
|
module_id_to_prefix[id(model)] = "" |
|
|
|
for module_name, module in model.named_modules(): |
|
if isinstance(module, TensorParallelColumnLinear): |
|
|
|
|
|
|
|
|
|
|
|
|
|
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { |
|
name for name, _ in module.named_parameters() |
|
} |
|
for param_name, param in module.named_parameters(): |
|
assert isinstance(param, NanotronParameter) |
|
if param.is_tied: |
|
tied_info = param.get_tied_info() |
|
full_param_name = tied_info.get_full_name_from_module_id_to_prefix( |
|
module_id_to_prefix=module_id_to_prefix |
|
) |
|
else: |
|
full_param_name = f"{module_name}.{param_name}" |
|
|
|
if full_param_name in initialized_parameters: |
|
|
|
continue |
|
|
|
if "weight" == param_name: |
|
init_method(param) |
|
elif "bias" == param_name: |
|
param.zero_() |
|
else: |
|
raise ValueError(f"Who the fuck is {param_name}?") |
|
|
|
assert full_param_name not in initialized_parameters |
|
initialized_parameters.add(full_param_name) |
|
elif isinstance(module, TensorParallelRowLinear): |
|
|
|
|
|
|
|
|
|
|
|
|
|
assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { |
|
name for name, _ in module.named_parameters() |
|
} |
|
for param_name, param in module.named_parameters(): |
|
assert isinstance(param, NanotronParameter) |
|
if param.is_tied: |
|
tied_info = param.get_tied_info() |
|
full_param_name = tied_info.get_full_name_from_module_id_to_prefix( |
|
module_id_to_prefix=module_id_to_prefix |
|
) |
|
else: |
|
full_param_name = f"{module_name}.{param_name}" |
|
|
|
if full_param_name in initialized_parameters: |
|
|
|
continue |
|
|
|
if "weight" == param_name: |
|
scaled_init_method(param) |
|
elif "bias" == param_name: |
|
param.zero_() |
|
else: |
|
raise ValueError(f"Who the fuck is {param_name}?") |
|
|
|
assert full_param_name not in initialized_parameters |
|
initialized_parameters.add(full_param_name) |
|
elif isinstance(module, TritonRMSNorm): |
|
assert {"weight"} == {name for name, _ in module.named_parameters()} |
|
for param_name, param in module.named_parameters(): |
|
assert isinstance(param, NanotronParameter) |
|
if param.is_tied: |
|
tied_info = param.get_tied_info() |
|
full_param_name = tied_info.get_full_name_from_module_id_to_prefix( |
|
module_id_to_prefix=module_id_to_prefix |
|
) |
|
else: |
|
full_param_name = f"{module_name}.{param_name}" |
|
|
|
if full_param_name in initialized_parameters: |
|
|
|
continue |
|
|
|
if "weight" == param_name: |
|
|
|
param.fill_(1) |
|
elif "bias" == param_name: |
|
param.zero_() |
|
else: |
|
raise ValueError(f"Who the fuck is {param_name}?") |
|
|
|
assert full_param_name not in initialized_parameters |
|
initialized_parameters.add(full_param_name) |
|
elif isinstance(module, TensorParallelEmbedding): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert {"weight"} == {name for name, _ in module.named_parameters()} |
|
|
|
assert isinstance(module.weight, NanotronParameter) |
|
if module.weight.is_tied: |
|
tied_info = module.weight.get_tied_info() |
|
full_param_name = tied_info.get_full_name_from_module_id_to_prefix( |
|
module_id_to_prefix=module_id_to_prefix |
|
) |
|
else: |
|
full_param_name = f"{module_name}.weight" |
|
|
|
if full_param_name in initialized_parameters: |
|
|
|
continue |
|
|
|
init_method(module.weight) |
|
assert full_param_name not in initialized_parameters |
|
initialized_parameters.add(full_param_name) |
|
|
|
assert initialized_parameters == { |
|
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) |
|
if param.is_tied |
|
else name |
|
for name, param in model.named_parameters() |
|
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" |
|
|
|
def get_block_compute_costs(self): |
|
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" |
|
return self.model.get_block_compute_costs() |
|
|
|
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): |
|
"""Get flops per second for a given model""" |
|
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) |
|
|
|
|
|
def get_flops( |
|
num_layers, |
|
hidden_size, |
|
    num_heads,
    num_key_value_heads,
    vocab_size,
|
seq_len, |
|
kv_channels=None, |
|
ffn_hidden_size=None, |
|
batch_size=1, |
|
recompute_granularity=None, |
|
glu_activation=False, |
|
): |
|
"""Counts flops in an decoder-only model |
|
Args: |
|
num_layers: number of decoder layers |
|
hidden_size: hidden size of the model |
|
num_heads: number of heads in the model |
|
num_key_value_heads: number of key/value heads in the model |
|
ffn_hidden_size: hidden size of the FFN |
|
vocab_size: size of the vocabulary |
|
seq_len: sequence length of the decoder |
|
batch_size: batch size |
|
recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info. |
|
Returns: |
|
model_flops: flops in the model (should be independent of the hardware and model implementation) |
|
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf |
|
""" |
|
if kv_channels is None: |
|
assert hidden_size % num_heads == 0 |
|
kv_channels = hidden_size // num_heads |
|
if ffn_hidden_size is None: |
|
ffn_hidden_size = 4 * hidden_size |
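    # Forward-pass FLOPs of each matmul below: 2 * (number of output elements) * (reduction dimension).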
|
|
|
|
|
|
|
|
|
|
|
decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels |
|
|
|
    decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * kv_channels
|
|
|
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len |
|
|
|
|
|
|
|
|
|
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels |
|
|
|
|
|
decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size |
|
|
|
|
|
decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size |
|
if glu_activation: |
|
|
|
|
|
|
|
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size |
|
|
|
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size |
|
|
|
decoder_flops_fwd = ( |
|
decoder_q_proj_flops_fwd |
|
+ decoder_kv_proj_flops_fwd |
|
+ decoder_qk_logits_flops_fwd |
|
+ decoder_v_logits_flops_fwd |
|
+ decoder_attn_out_flops_fwd |
|
+ decoder_ffn_1_flops_fwd |
|
+ decoder_ffn_2_flops_fwd |
|
) |
|
|
|
|
|
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size |
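    # The backward pass costs roughly twice the forward pass, hence the factor of 3 (fwd + bwd).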
|
|
|
|
|
|
|
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) |
|
|
|
if recompute_granularity is None: |
|
hardware_flops = model_flops |
|
elif recompute_granularity is RecomputeGranularity.FULL: |
|
|
|
hardware_flops = model_flops + decoder_flops_fwd |
|
elif recompute_granularity is RecomputeGranularity.SELECTIVE: |
|
|
|
|
|
recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd |
|
hardware_flops = model_flops + recomputed_decoder_flops |
|
else: |
|
raise ValueError("recompute_granularity must be one of 'full' or 'selective'") |
|
|
|
return model_flops, hardware_flops |
|
|