fix: force correct dtype in HF load
Browse files- modeling_hyena.py +3 -0
modeling_hyena.py
CHANGED
@@ -46,6 +46,9 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
|
|
46 |
self.vocab_size = vocab_size
|
47 |
self.post_init()
|
48 |
|
|
|
|
|
|
|
49 |
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
|
50 |
self.backbone.gradient_checkpointing = enable
|
51 |
|
|
|
46 |
self.vocab_size = vocab_size
|
47 |
self.post_init()
|
48 |
|
49 |
+
def post_init(self):
|
50 |
+
self.backbone.to_bfloat16_except_poles_residues()
|
51 |
+
|
52 |
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
|
53 |
self.backbone.gradient_checkpointing = enable
|
54 |
|