winglian commited on
Commit
93dbd18
1 Parent(s): 463565d

Enable flash_attention_2 support since the underlying Mistral model supports it

Browse files
Files changed (1) hide show
  1. modeling_eurus_rm.py +2 -0
modeling_eurus_rm.py CHANGED
@@ -5,6 +5,8 @@ from typing import Optional, List
5
 
6
  class EurusRewardModel(PreTrainedModel):
7
  config_class = MistralConfig
 
 
8
  def __init__(self, config):
9
  super().__init__(config)
10
  self.model = MistralModel(config)
 
5
 
6
  class EurusRewardModel(PreTrainedModel):
7
  config_class = MistralConfig
8
+ _supports_flash_attn_2 = True
9
+
10
  def __init__(self, config):
11
  super().__init__(config)
12
  self.model = MistralModel(config)