Enable flash_attention_2 support since the underlying Mistral model supports it
modeling_eurus_rm.py CHANGED (+2 -0)
```diff
@@ -5,6 +5,8 @@ from typing import Optional, List
 
 class EurusRewardModel(PreTrainedModel):
     config_class = MistralConfig
+    _supports_flash_attn_2 = True
+
     def __init__(self, config):
         super().__init__(config)
         self.model = MistralModel(config)
```
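With `_supports_flash_attn_2 = True` set on the custom `PreTrainedModel` subclass, Transformers will accept `attn_implementation="flash_attention_2"` when loading the model. A minimal usage sketch follows; the checkpoint id is an assumption, so substitute the repo that actually hosts this `modeling_eurus_rm.py`:

```python
# Minimal sketch: loading the reward model with FlashAttention-2 enabled.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/Eurus-RM-7b",                    # assumed checkpoint id, replace as needed
    trust_remote_code=True,                   # required for the custom EurusRewardModel code
    torch_dtype=torch.bfloat16,               # FlashAttention-2 requires fp16 or bf16
    attn_implementation="flash_attention_2",  # accepted now that _supports_flash_attn_2 is set
)
```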