import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from mamba_ssm.ops.selective_scan_interface import (
        selective_scan_ref,
        selective_scan_fn,
        mamba_inner_fn,
    )
    from causal_conv1d import causal_conv1d_fn
    import einops
except ModuleNotFoundError:
    print(
        "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, "
        "or directly from https://github.com/state-spaces/mamba"
    )

from megatron.model.norms import get_norm
from megatron import mpu
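

# Mamba ("S6" selective state-space) mixer block. The input, output, and x
# projections are tensor-parallelized with NeoX's mpu layers, and the conv/SSM
# parameters are partitioned along the inner dimension.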
class ParallelMambaBlock(nn.Module):
    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
    ):
        super().__init__()

        self.neox_args = neox_args

        # map the NeoX precision setting to the dtype used for parameter creation
        dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }[neox_args.precision]
        self.precision = dtype
        factory_kwargs = {"device": torch.cuda.current_device(), "dtype": dtype}

        assert not (
            neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion
        ), "Mamba fused inner fn and bias in x_proj not compatible!"

        assert (
            neox_args.intermediate_size is None or neox_args.expansion_factor is None
        ), "Pass either the absolute intermediate size or the relative expansion factor for the mamba projections, not both"
        self.d_model = neox_args.hidden_size
        self.d_state = 16
        self.d_conv = 4
        if neox_args.intermediate_size:
            self.d_inner = neox_args.intermediate_size
        else:
            self.expand = (
                neox_args.expansion_factor if neox_args.expansion_factor else 2
            )
            self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16)
        self.dt_scale = 1.0

        self.dt_init = "random"
        self.dt_min, self.dt_max, self.dt_init_floor = 0.001, 0.1, 1e-4
        assert self.dt_init in ["constant", "random"]
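
        # partition the inner dimension across model-parallel ranks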
        world_size = mpu.get_model_parallel_world_size()
        self.d_inner_per_rank = mpu.divide(self.d_inner, world_size)
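
        # the fused mamba_inner_fn applies out_proj.weight directly, bypassing
        # RowParallelLinear's all-reduce, so keep a handle to reduce manually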
        if neox_args.mamba_inner_func_fusion and world_size > 1:
            self.reduce = mpu.mappings.reduce_from_model_parallel_region
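
        # up-projection from d_model to 2 * d_inner (SSM input x and gate z),
        # column-parallel across ranks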
        self.in_proj = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.d_model,
            output_size=self.d_inner * 2,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            bias=neox_args.mamba_use_bias_in_linears,
        )
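
        # depthwise (groups == channels) causal convolution over the sequence,
        # holding only this rank's d_inner_per_rank channels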
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner_per_rank,
            out_channels=self.d_inner_per_rank,
            bias=neox_args.mamba_use_bias_in_conv,
            kernel_size=self.d_conv,
            groups=self.d_inner_per_rank,
            padding=self.d_conv - 1,
            **factory_kwargs,
        )
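        # defensive cast: keep the conv parameters in the working precision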
        self.conv1d.to(self.precision)

        self.act_fn = F.silu  # the Mamba block uses SiLU as its activation
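
        # x_proj produces the input-dependent dt, B, and C
        # (s_Delta, s_B, s_C in the Mamba paper's selection mechanism)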
        self.x_proj = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=self.d_inner,
            output_size=self.dt_rank + self.d_state * 2,
            input_is_parallel=True,
            init_method=init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            parallel_output=True,
            bias=neox_args.mamba_use_bias_in_linears,
        )
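
        # dt_proj maps dt from dt_rank up to d_inner_per_rank; bias=True is kept
        # regardless of the other linears, since the bias carries the special
        # dt initialization below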
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner_per_rank, bias=True, **factory_kwargs
        )

        dt_init_std = (self.dt_rank**-0.5) * self.dt_scale
        if self.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif self.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
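
        # initialize the dt_proj bias: sample dt log-uniformly in [dt_min, dt_max],
        # clamp at dt_init_floor, and store its softplus inverse so that
        # softplus(dt_proj(...)) starts out in that range (reference Mamba init)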
        dt = torch.exp(
            torch.rand(self.d_inner_per_rank, **factory_kwargs)
            * (math.log(self.dt_max) - math.log(self.dt_min))
            + math.log(self.dt_min)
        ).clamp(min=self.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
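
        # S4D-real style initialization: A takes values 1..d_state along the state
        # dimension, repeated per channel; stored as log(A) in fp32 and excluded
        # from weight decay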
        A = einops.repeat(
            torch.arange(
                1,
                self.d_state + 1,
                dtype=torch.float32,
                device=torch.cuda.current_device(),
            ),
            "n -> d n",
            d=self.d_inner_per_rank,
        ).contiguous()
        A_log = torch.log(A).to(torch.float32)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # keep A_log in fp32 when mamba_selective_fp32_params is set
        # (the flag tells DeepSpeed not to cast this param down)
        if self.neox_args.mamba_selective_fp32_params:
            self.A_log._deepspeed_no_cast = True
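
        # D "skip" parameter (per-channel residual scaling), kept in fp32 and
        # excluded from weight decay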
        self.D = nn.Parameter(
            torch.ones(
                self.d_inner_per_rank,
                device=torch.cuda.current_device(),
                dtype=torch.float32,
            )
        ).to(torch.float32)
        self.D._no_weight_decay = True

        # as with A_log, optionally keep D in fp32
        if self.neox_args.mamba_selective_fp32_params:
            self.D._deepspeed_no_cast = True
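
        # down-projection from d_inner back to d_model; row-parallel, so it also
        # reduces partial results across ranks (except on the fused path, which
        # bypasses this module and uses self.reduce instead)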
        self.out_proj = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=self.d_inner,
            output_size=self.d_model,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            bias=neox_args.mamba_use_bias_in_linears,
            parallel_output=False,
        )
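
    # selective (S6) scan over the sequence: dispatches to the reference
    # implementation or the fused CUDA kernel based on mamba_selective_scan_fusion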
    def selective_scan(
        self,
        x,
        dt,
        A,
        B,
        C,
        D,
        z=None,
        delta_bias=None,
        delta_softplus=True,
    ):
        if not self.neox_args.mamba_selective_scan_fusion:
            y = selective_scan_ref(
                u=x,
                delta=dt,
                A=A,
                B=B,
                C=C,
                D=D,
                z=z,
                delta_bias=delta_bias,
                delta_softplus=delta_softplus,
                return_last_state=False,
            )
        else:
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                D=D,
                z=z,
                delta_bias=delta_bias,
                delta_softplus=delta_softplus,
                return_last_state=False,
            )

        return y

    def forward(self, hidden_states):
        """hidden_states: [seqlen, batch, hidden]. Returns output of the same shape."""
        assert self.training, "Mamba in NeoX does not support inference!"

        seqlen, batch, dim = hidden_states.shape

        # joint up-projection for x and z: [l, b, d_model] -> [b, 2 * d_inner, l]
        xz, _ = self.in_proj(hidden_states)
        xz = einops.rearrange(xz, "l b d -> b d l")

        # recover the (negative, real) state matrix from its log parameterization
        A = -torch.exp(self.A_log.float())
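
        # fused path: mamba_inner_fn computes the causal conv, the selective scan,
        # and the output projection in a single fused call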
        if self.neox_args.mamba_inner_func_fusion:
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                # cast the conv bias in case it is still held in fp32
                self.conv1d.bias.to(self.precision)
                if self.conv1d.bias is not None
                else self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # B is input-dependent, computed from x_proj inside the kernel
                None,  # C is input-dependent, computed from x_proj inside the kernel
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
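            # the fused call skipped out_proj's RowParallelLinear, so reduce the
            # partial outputs across model-parallel ranks here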
            if getattr(self, "reduce", None):
                out = self.reduce(out)

            out = einops.rearrange(out, "b l h -> l b h")

            return out
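
        # unfused path: split the up-projection into the SSM input x and the gate z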
        x, z = xz.chunk(2, dim=1)
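
        # causal depthwise convolution along the sequence (padding trimmed back
        # to seqlen), followed by the SiLU activation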
        if not self.neox_args.mamba_causal_conv_fusion:
            self.conv1d.to(self.precision)
            x = self.act_fn(self.conv1d(x)[..., :seqlen])
        else:
            # causal_conv1d_fn fuses the convolution with the SiLU activation
            x = causal_conv1d_fn(
                x=x,
                weight=einops.rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias.to(self.precision)
                if self.conv1d.bias is not None
                else self.conv1d.bias,
                activation="silu",
            )
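
        # compute the input-dependent dt, B, and C from x, then restore the
        # (batch, ..., seqlen) layouts expected by the scan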
        x_dbl, _ = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))

        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )

        dt = self.dt_proj.weight @ dt.t()
        dt = einops.rearrange(dt, "d (b l) -> b d l", l=seqlen)

        B = einops.rearrange(B, "(b l) d_state -> b d_state l", l=seqlen).contiguous()
        C = einops.rearrange(C, "(b l) d_state -> b d_state l", l=seqlen).contiguous()
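
        # selective state-space scan; the gate z and the D skip term are applied
        # inside the scan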
        y = self.selective_scan(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
        )
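
        # down-project back to d_model and restore the [seqlen, batch, hidden] layout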
        y = einops.rearrange(y, "b d l -> b l d")

        out, _ = self.out_proj(y)

        out = einops.rearrange(out, "b l h -> l b h")

        return out


class ParallelMambaResidualLayer(nn.Module):
    """
    Pre-norm Mamba Block with residual connection. No parallelism yet supported.
    """

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        layer_number,
    ):
        super().__init__()

        self.layer_number = layer_number

        norm, eps = get_norm(neox_args)
        self.norm = norm(neox_args.hidden_size, eps=eps)

        self.mixer = ParallelMambaBlock(
            neox_args=neox_args,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
        )

    def forward(self, x, attention_mask=None, layer_past=None):
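        # pre-norm residual: x = x + mixer(norm(x)); attention_mask and layer_past
        # are accepted only for interface compatibility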
        residual = x

        hidden_states = self.mixer(self.norm(x))

        return hidden_states + residual


class ParallelMambaResidualLayerPipe(ParallelMambaResidualLayer):
    """Extends ParallelMambaResidualLayer to forward attention_mask through the pipeline. DeepSpeed requires this."""

    def forward(self, args):
        assert (
            len(args) == 2
        ), "MambaResidualLayerPipe expects 2 arguments - hidden_states and attention_mask"
        hidden_states, attention_mask = args
        return super().forward(hidden_states, attention_mask), attention_mask