Add explicit error for flash + alibi
Browse files- mosaic_gpt.py +3 -0
mosaic_gpt.py
CHANGED
@@ -31,6 +31,9 @@ class MosaicGPT(PreTrainedModel):
|
|
31 |
def __init__(self, config: MosaicGPTConfig):
|
32 |
super().__init__(config)
|
33 |
|
|
|
|
|
|
|
34 |
self.attn_impl = config.attn_impl
|
35 |
self.prefix_lm = config.prefix_lm
|
36 |
self.attn_uses_sequence_id = config.attn_uses_sequence_id
|
|
|
31 |
def __init__(self, config: MosaicGPTConfig):
|
32 |
super().__init__(config)
|
33 |
|
34 |
+
if config.attn_impl == 'flash' and config.alibi:
|
35 |
+
raise RuntimeError("ALiBi is not supported with flash attention. Please use triton or torch.")
|
36 |
+
|
37 |
self.attn_impl = config.attn_impl
|
38 |
self.prefix_lm = config.prefix_lm
|
39 |
self.attn_uses_sequence_id = config.attn_uses_sequence_id
|