import json

from transformers import PretrainedConfig


class StripedHyenaConfig(PretrainedConfig):
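    """Configuration class for StripedHyena models (hybrid attention / Hyena layer stacks)."""
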
    model_type = "stripedhyena"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_filters=4096,
        inner_mlp_size=14336,
        attn_layer_idxs=[],
        hyena_layer_idxs=[],
        num_layers=32,
        tie_embeddings=False,
        short_filter_length=3,
        num_attention_heads=32,
        proj_groups=4,
        hyena_filter_groups=1,
        split_k0=True,
        column_split_hyena=True,
        column_split=False,
        model_parallel_size=1,
        pipe_parallel_size=1,
        short_filter_bias=True,
        mha_out_proj_bias=False,
        qkv_proj_bias=False,
        final_norm=True,
        use_cache=True,
        use_flash_attention_2=True,
        use_flash_rmsnorm=True,
        use_flash_depthwise=False,
        use_flashfft=False,
        inference_mode=False,
        prefill_style="fft",
        max_seqlen=32768,
        eps=1e-5,
        state_size=2,
        rotary_emb_base=500000,
        smeared_gqa=False,
        make_vocab_size_divisible_by=8,
        log_intermediate_values=False,
        **kwargs,
    ):
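        """Store the StripedHyena hyperparameters; remaining kwargs are forwarded to PretrainedConfig."""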
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.inner_mlp_size = inner_mlp_size
        self.attn_layer_idxs = attn_layer_idxs
        self.hyena_layer_idxs = hyena_layer_idxs
        self.num_layers = num_layers
        self.tie_embeddings = tie_embeddings
        self.short_filter_length = short_filter_length
        self.num_attention_heads = num_attention_heads
        self.proj_groups = proj_groups
        self.hyena_filter_groups = hyena_filter_groups
        self.split_k0 = split_k0
        self.column_split_hyena = column_split_hyena
        self.column_split = column_split
        self.model_parallel_size = model_parallel_size
        self.pipe_parallel_size = pipe_parallel_size
        self.short_filter_bias = short_filter_bias
        self.mha_out_proj_bias = mha_out_proj_bias
        self.qkv_proj_bias = qkv_proj_bias
        self.final_norm = final_norm
        self.use_cache = use_cache
        self.use_flash_attention_2 = use_flash_attention_2
        self.use_flash_rmsnorm = use_flash_rmsnorm
        self.use_flash_depthwise = use_flash_depthwise
        self.use_flashfft = use_flashfft
        self.inference_mode = inference_mode
        self.prefill_style = prefill_style
        self.max_seqlen = max_seqlen
        self.eps = eps
        self.state_size = state_size
        self.rotary_emb_base = rotary_emb_base
        self.smeared_gqa = smeared_gqa
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
        self.log_intermediate_values = log_intermediate_values
        super().__init__(**kwargs)

    def to_dict(self):
        """Return all instance attributes as a plain dict."""
        return {attr: getattr(self, attr) for attr in self.__dict__}

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a config from an original StripedHyena JSON config file."""
        with open(config_path, "r") as f:
            config = json.load(f)
        return cls(**config, **kwargs)
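

# Illustrative usage sketch (not part of the original file): builds a config with a
# couple of overrides and round-trips it through to_dict(). The JSON path mentioned
# below is hypothetical.
if __name__ == "__main__":
    config = StripedHyenaConfig(num_layers=16, max_seqlen=8192)
    assert config.to_dict()["num_layers"] == 16
    # To load from an original StripedHyena JSON config (example path only):
    # config = StripedHyenaConfig.from_original_config("path/to/config.json")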