aria-dev committed on
Commit
006dc0f
1 Parent(s): 90fb8b4

support eager attention

Browse files
Files changed (3) hide show
  1. config.json +2 -0
  2. modeling_aria.py +1 -0
  3. vision_encoder.py +0 -1
config.json CHANGED
@@ -30,8 +30,10 @@
30
  },
31
  "torch_dtype": "bfloat16",
32
  "transformers_version": "4.45.0",
 
33
  "vision_config": {
34
  "_flash_attn_2_enabled": true,
 
35
  "architectures": [
36
  "AriaVisionModel"
37
  ],
 
30
  },
31
  "torch_dtype": "bfloat16",
32
  "transformers_version": "4.45.0",
33
+ "_attn_implementation": "flash_attention_2",
34
  "vision_config": {
35
  "_flash_attn_2_enabled": true,
36
+ "_attn_implementation": "flash_attention_2",
37
  "architectures": [
38
  "AriaVisionModel"
39
  ],
modeling_aria.py CHANGED
@@ -133,6 +133,7 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
133
  def __init__(self, config: AriaConfig):
134
  super().__init__(config)
135
 
 
136
  self.vision_tower = AriaVisionModel(config.vision_config)
137
  self.multi_modal_projector = build_mm_projector(config)
138
  self.vocab_size = config.text_config.vocab_size
 
133
  def __init__(self, config: AriaConfig):
134
  super().__init__(config)
135
 
136
+ config.vision_config._attn_implementation = config._attn_implementation
137
  self.vision_tower = AriaVisionModel(config.vision_config)
138
  self.multi_modal_projector = build_mm_projector(config)
139
  self.vocab_size = config.text_config.vocab_size
vision_encoder.py CHANGED
@@ -38,7 +38,6 @@ class AriaVisionConfig(SiglipVisionConfig):
38
  **kwargs,
39
  ):
40
  super().__init__(**kwargs)
41
- self._attn_implementation = "flash_attention_2"
42
 
43
 
44
  class IdentityOp(torch.nn.Module):
 
38
  **kwargs,
39
  ):
40
  super().__init__(**kwargs)
 
41
 
42
 
43
  class IdentityOp(torch.nn.Module):