# Model page metadata (scraped): Text Generation · Transformers · Safetensors ·
# English · stripedhyena · custom_code
# Last commit by Zymrael: "chore: add checkpoint import" (b48efba)
# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
# Note: MP and PP utilities are removed for ease of use and editing.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from .utils import print_rank_0, column_split
from .cache import InferenceParams, RecurrentInferenceParams
from .engine import HyenaInferenceEngine
from .layers import (
RMSNorm,
ParallelGatedMLP,
VocabParallelEmbedding,
)
try:
from flash_attn.modules.mha import MHA
except ImportError:
"flash_attn not installed"
class AttentionBlock(nn.Module):
    """Pre-norm residual block: multi-head attention followed by a gated MLP.

    ``forward`` computes ``u = mha(pre_norm(u)) + u`` and then
    ``u = mlp(post_norm(u)) + u``, optionally zeroing padded positions.
    """

    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        # Key/value heads = num_attention_heads // proj_groups (grouped-query attention).
        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
        self.counter = 0
        # NOTE(review): MHA comes from the optional flash_attn import at module top;
        # if that import failed, this line raises NameError.
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 10000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)
        if self.config.get("smeared_gqa", False):
            # Use as many key/value heads as query heads.
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        # register_buffer defaults to persistent=True, so inv_freq becomes part of the
        # state_dict (presumably so it is loaded from the checkpoint; confirm).
        self.inner_mha_cls.rotary_emb.register_buffer(
            "inv_freq", self.inner_mha_cls.rotary_emb.inv_freq
        )
        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            u: block input; ``padding_mask[..., None]`` must broadcast against it.
            inference_params: forwarded unchanged to the attention module.
            padding_mask: optional tensor (any ``torch.Tensor`` subclass); positions
                where it is 0 are zeroed before attention and before the MLP.

        Returns:
            ``(output, None)``; the second element keeps the block interface uniform.
        """
        has_mask = isinstance(padding_mask, torch.Tensor)
        if has_mask:
            # Workaround for a masking bug in FA: zero padded positions before attention.
            # NOTE(review): the original rationale assumed Wqkv has no bias, but
            # qkv_proj_bias defaults to True above; verify padded keys are excluded.
            u = u * padding_mask[..., None]
        u = self.inner_mha_cls(self.pre_norm(u), inference_params=inference_params) + u
        if has_mask:
            # Re-zero padded positions: the attention output projection may add a bias.
            u = u * padding_mask[..., None]
        u = self.mlp(self.post_norm(u)) + u
        return u, None
class ParallelHyenaFilter(nn.Module):
    """Hyena filter: short depthwise FIR followed by a long filter built from poles/residues.

    The long filter's impulse response is ``h[t] = Re(sum_k residues[k] * poles[k] ** t)``
    (see :meth:`compute_filter`), unless a precomputed filter is stored in ``self.h``.
    Full sequences go through :meth:`parallel_forward`; once this layer's FIR state is
    cached in ``inference_params``, calls go through :meth:`sequential_forward`, which
    processes only the last position.
    """

    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
        self.use_flashfft = config.get("use_flashfft", False)
        self.state_size = config.state_size
        self.hidden_size = config.hidden_size
        self.num_filters = config.num_filters
        self.inference_mode = config.get("inference_mode", True)
        self.counter = 0
        self.column_split_hyena = config.get("column_split_hyena", True)
        assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
        # Per-channel parameter passed to the engine as ``D`` (presumably the
        # feed-through/skip term; confirm in HyenaInferenceEngine).
        self.D = nn.Parameter(torch.zeros(self.hidden_size))
        # attention heads are not used except to split post short_filter
        # projections in the same way as the checkpoint
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        # after preprocessing here we can save the new checkpoint
        self.short_filter_length = config.short_filter_length
        # One FIR kernel per channel of the 3 * hidden_size projected input.
        self.short_filter_weight = nn.Parameter(
            torch.randn(3 * config.hidden_size, 1, config.short_filter_length)
        )
        self.short_filter_bias = (
            nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
        )
        self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
        self.use_flash_depthwise = config.get("use_flash_depthwise", False)
        # Input dtype, captured on the first sequential_forward call.
        self.data_dtype = None
        if self.use_flash_depthwise:
            # NOTE(review): FlashDepthwiseConv1d is not imported in the visible part of
            # this module; unless it is defined elsewhere, this branch raises NameError.
            self.fir_fn = FlashDepthwiseConv1d(
                channels=3 * self.hidden_size,
                kernel_size=self.short_filter_length,
                padding=self.short_filter_length - 1,
                weights=self.short_filter_weight,
                bias=self.short_filter_bias,
                device=None,
                dtype=self.config.get("depthwise_dtype", torch.bfloat16),
            )
        else:
            self.fir_fn = F.conv1d
        # Fused FFT-conv kernel; assigned externally (get_block sets it when use_flashfft).
        self.fftconv_fn = None
        self.long_fir_threshold = config.get("long_fir_threshold", None)
        if self.long_fir_threshold is not None:
            assert (
                self.use_flashfft is False
            ), "long_fir_threshold not compatible with fused flashfft"
        self.num_systems = self.hidden_size // self.hyena_filter_groups
        # Complex values stored as real (..., 2) pairs; viewed as complex in compute_filter.
        self.poles = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
        self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
        # Optional precomputed long filter; when None it is rebuilt on each parallel_forward.
        self.h = None

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        """Run one decoding step if this layer's FIR state is cached, else the full sequence.

        Returns:
            ``(y, inference_params)``.
        """
        if inference_params is not None and self.layer_idx in inference_params.fir_state_dict:
            return self.sequential_forward(u, inference_params)
        return self.parallel_forward(u, inference_params, padding_mask)

    def parallel_forward(self, u, inference_params=None, padding_mask=None):
        """Filter a whole sequence (``u.shape[1]`` is the length) with the FIR then long filter.

        When ``inference_params`` is truthy, the resulting FIR state is stored in
        ``inference_params.fir_state_dict[self.layer_idx]`` so later calls take the
        sequential path.
        """
        L = u.shape[1]
        z_pre, fir_state = self.engine.parallel_fir(
            self.fir_fn,
            u,
            self.short_filter_weight,
            self.short_filter_bias,
            L,
            fir_length=self.short_filter_length,
            inference_params=inference_params,
            padding_mask=padding_mask,
        )
        if inference_params:
            inference_params.fir_state_dict[self.layer_idx] = fir_state
        # compute_filter also refreshes self.t, which is passed to the engine below.
        h = self.compute_filter(L, u.device)[0] if self.h is None else self.h
        if self.hyena_filter_groups > 1:
            h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)
        # if inference_params is not None, we plan to perform generation:
        # prefilling for the IIR portion of the filter is handled by the engine.
        dims = (
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.state_size,
            self.hyena_filter_groups,
        )
        y = self.engine.parallel_iir(
            z_pre,
            h,
            self.D,
            L,
            t=self.t,
            poles=self.poles,
            dims=dims,
            inference_params=inference_params,
            layer_idx=self.layer_idx,
            prefill_style=self.config.get("prefill_style", "fft"),
            use_flashfft=self.use_flashfft,
            fftconv_fn=self.fftconv_fn,
            column_split_hyena=self.column_split_hyena,
            long_fir_threshold=self.long_fir_threshold,
            padding_mask=padding_mask,
        )
        return y, inference_params

    def sequential_forward(self, u, inference_params):
        """Advance the cached FIR and IIR states by one step.

        Only the last position ``u[:, -1]`` is used when ``u`` has more than two dims.
        Updated states are written back into ``inference_params``; the output is cast
        to the dtype of the first input seen and returned with a length-1 axis 1.
        """
        if self.data_dtype is None:
            self.data_dtype = u.dtype
        if len(u.shape) > 2:
            u = u[:, -1]
        fir_state, iir_state = (
            inference_params.fir_state_dict[self.layer_idx],
            inference_params.state_dict[self.layer_idx],
        )
        z_pre, fir_state = self.engine.step_fir(
            u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
        )
        x2, x1, v = (
            column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
            if self.column_split_hyena
            else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
        )
        y, iir_state = self.engine.step_iir(
            x2,
            x1,
            v,
            self.D,
            self.residues,
            self.poles,
            iir_state,
            iir_groups=self.hyena_filter_groups,
        )
        inference_params.fir_state_dict[self.layer_idx] = fir_state
        inference_params.state_dict[self.layer_idx] = iir_state
        y = y.to(dtype=self.data_dtype)
        return y[:, None], inference_params

    def update_time(self, L, device):
        """
        Set ``self.t`` to [0, 1, ..., L-1] with shape (1, 1, L).

        The cached vector is rebuilt when it is missing, shorter than ``L``, or on a
        different device than ``device`` (``t`` is a plain attribute, not a buffer, so
        ``module.to(...)`` does not move it). Otherwise it is truncated to length ``L``.
        """
        needs_init = (
            not hasattr(self, "t")
            or self.t.shape[-1] < L
            or (device is not None and self.t.device != torch.device(device))
        )
        if needs_init:
            self.t = torch.arange(L, device=device)[None, None]
        else:
            self.t = self.t[..., :L]

    def compute_filter(self, L, device):
        """Build the length-``L`` long filter in float32.

        Computes ``h = Re(sum_k residues[k] * exp(log(poles[k]) * t))``, summing over the
        state dimension.

        Returns:
            ``(h, filter_dtype, log_poles, residues)`` where ``h`` has shape
            ``(1, num_systems, L)`` and ``log_poles``/``residues`` are complex.
        """
        self.update_time(L, device)
        filter_dtype = torch.float32
        residues, log_poles = (
            torch.view_as_complex(self.residues.to(filter_dtype)),
            torch.view_as_complex(self.poles.to(filter_dtype)).log(),
        )
        h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
        return h, filter_dtype, log_poles, residues
class ParallelGatedConvBlock(nn.Module):
    """Pre-norm residual block: Hyena filter mixer followed by a gated MLP.

    ``forward`` projects the input to ``3 * hidden_size`` channels, runs the
    :class:`ParallelHyenaFilter`, maps back with ``out_filter_dense`` and adds the
    residual, then applies ``mlp(post_norm(u)) + u``.
    """

    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        dtype = config.get("hyena_block_dtype", torch.float32)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(
            dtype=dtype
        )
        self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
        # Input projection to the three filter branches (not cast to ``dtype``).
        self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
        self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        """Apply the filter and MLP sub-layers with residual connections.

        Args:
            u: block input; ``padding_mask[..., None]`` must broadcast against it.
            inference_params: forwarded to the filter and returned updated.
            padding_mask: optional tensor (any ``torch.Tensor`` subclass); positions
                where it is 0 are zeroed after each biased linear layer.

        Returns:
            ``(output, inference_params)``.
        """
        has_mask = isinstance(padding_mask, torch.Tensor)
        z = self.projections(self.pre_norm(u))
        if has_mask:  # the projection bias would otherwise leak into padded positions
            z = z * padding_mask[..., None]
        z, inference_params = self.filter(
            z, inference_params=inference_params, padding_mask=padding_mask
        )
        u = self.out_filter_dense(z) + u
        if has_mask:  # same guard for out_filter_dense's bias
            u = u * padding_mask[..., None]
        u = self.mlp(self.post_norm(u)) + u
        return u, inference_params
def get_block(config, layer_idx, flash_fft=None):
    """Build the block for ``layer_idx`` according to the config's layer layout.

    Args:
        config: model config providing ``attn_layer_idxs``, ``hyena_layer_idxs`` and ``get``.
        layer_idx: index of the layer to build.
        flash_fft: fused FFT-conv callable installed on Hyena filters when
            ``config["use_flashfft"]`` is truthy.

    Returns:
        An ``AttentionBlock`` or a ``ParallelGatedConvBlock``.

    Raises:
        NotImplementedError: if ``layer_idx`` is in neither layer list.
    """
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)
    if layer_idx in config.hyena_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        # The default must be the boolean False: the string "False" is truthy.
        if config.get("use_flashfft", False):
            block.filter.fftconv_fn = flash_fft
        return block
    raise NotImplementedError(
        f"layer_idx {layer_idx} is in neither attn_layer_idxs nor hyena_layer_idxs"
    )
class StripedHyena(nn.Module):
    """Stack of attention and Hyena blocks between an embedding and an unembedding.

    Layer types are chosen per index by ``get_block`` from ``config.attn_layer_idxs``
    and ``config.hyena_layer_idxs``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
        # Tied embeddings reuse the input embedding module for unembedding.
        self.unembed = (
            self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
        )
        self.gradient_checkpointing = False
        # The default must be the boolean False: the string "False" is truthy and would
        # make every config without this key raise.
        if config.get("use_flashfft", False):
            raise NotImplementedError("Please use standalone SH code for other custom kernels")
        self.flash_fft = None
        self.blocks = nn.ModuleList(
            get_block(config, layer_idx, flash_fft=self.flash_fft)
            for layer_idx in range(config.num_layers)
        )

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        """Embed ``x``, run all blocks, apply the optional final norm, and unembed.

        Args:
            x: input passed to ``embedding_layer.embed``.
            inference_params_dict: when given, blocks run statefully with cached
                parameters (see :meth:`initialize_inference_params`); ``padding_mask``
                is not used on that path.
            padding_mask: optional mask used on the stateless path.

        Returns:
            ``(logits, inference_params_dict_out)``; the second item is None when
            running statelessly.
        """
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
        # self.norm is None when config["final_norm"] is False.
        if self.norm is not None:
            x = self.norm(x)
        x = self.unembed.unembed(x)
        return x, inference_params_dict_out

    def stateful_forward(self, x, inference_params_dict=None):
        """Run every block with its cache: "mha" params for attention, "hyena" otherwise."""
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            x, _ = block(x, inference_params=inference_params)
        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        """Run every block without caches, using activation checkpointing when training
        with ``gradient_checkpointing`` enabled."""
        if isinstance(padding_mask, torch.Tensor):
            x = x * padding_mask[..., None]
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, inference_params=None, padding_mask=padding_mask)

                    return custom_forward

                x, _ = checkpoint(create_custom_forward(block), x, use_reentrant=False)
            else:
                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
        return x, None

    def initialize_inference_params(self):
        """Create fresh caches: ``"mha"`` for attention layers, ``"hyena"`` for Hyena layers."""
        print_rank_0("Initializing inference params...")
        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=self.config.get("max_seqlen", 8192),
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hyena": RecurrentInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
        """Precompute and store ``filter.h`` for every Hyena filter.

        Each filter uses its own ``long_fir_threshold`` as length when set, else ``L``.
        The result is cast to float16 for lengths >= 2048, float32 otherwise.
        """
        for block in self.blocks:
            if not (
                isinstance(block, ParallelGatedConvBlock)
                and isinstance(block.filter, ParallelHyenaFilter)
            ):
                continue
            # Per-block length: do not let one block's threshold overwrite L for the rest.
            filter_len = block.filter.long_fir_threshold or L
            print_rank_0(f"Precomputing filters, L={filter_len}...")
            filter_dtype = torch.float16 if filter_len >= 2048 else torch.float32
            block.filter.update_time(filter_len, device)
            # NOTE(review): poles/residues are evaluated in float16 (complex32) even when
            # the stored filter is float32; confirm this precision is intended.
            residues, poles = (
                torch.view_as_complex(block.filter.residues.to(torch.float16)),
                torch.view_as_complex(block.filter.poles.to(torch.float16)),
            )
            h = (residues * poles**block.filter.t).real.sum(1)[None]
            block.filter.h = h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        """Load different poles and residues for each layer.

        Reads ``{path}/approx_poles_{i}.pt`` and ``{path}/approx_residues_{i}.pt``
        (1-based block index ``i``), converts complex tensors to real pairs and
        rearranges them into the filter's parameter layout.
        """
        for block_idx, block in enumerate(self.blocks):
            if not (
                isinstance(block, ParallelGatedConvBlock)
                and isinstance(block.filter, ParallelHyenaFilter)
            ):
                continue
            print(f"Loading poles and residues for block {block_idx}")
            # NOTE(review): torch.load unpickles its input; only use trusted files.
            poles = torch.load(f"{path}/approx_poles_{block_idx + 1}.pt", map_location="cpu")
            residues = torch.load(
                f"{path}/approx_residues_{block_idx + 1}.pt", map_location="cpu"
            )
            poles = torch.view_as_real(poles).permute(1, 0, 2).unsqueeze(-2)
            residues = torch.view_as_real(residues).permute(1, 0, 2).unsqueeze(-2)
            block.filter.poles = nn.Parameter(poles)
            block.filter.residues = nn.Parameter(residues)

    def to_bfloat16_except_poles_residues(self):
        """Convert all parameters to bfloat16 except for the poles and residues.

        Particularly important for longer prompts. Parameters are matched by name:
        any name containing "poles" or "residues" is left untouched.
        """
        for name, param in self.named_parameters():
            if "poles" not in name and "residues" not in name:
                param.data = param.data.to(torch.bfloat16)