support eager attention
- config.json +2 -0
- modeling_aria.py +1 -0
- vision_encoder.py +0 -1
config.json
CHANGED
@@ -30,8 +30,10 @@
   },
   "torch_dtype": "bfloat16",
   "transformers_version": "4.45.0",
+  "_attn_implementation": "flash_attention_2",
   "vision_config": {
     "_flash_attn_2_enabled": true,
+    "_attn_implementation": "flash_attention_2",
     "architectures": [
       "AriaVisionModel"
     ],
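With `_attn_implementation` stored in config.json rather than hard-coded, the attention backend can be overridden at load time. A minimal sketch of the intended use, assuming the standard transformers remote-code loading path; the repo id below is an assumption for illustration:

```python
from transformers import AutoModelForCausalLM

# Load with eager attention instead of the flash_attention_2 default
# written in config.json; "rhymes-ai/Aria" is an assumed repo id.
model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype="bfloat16",
    trust_remote_code=True,
    attn_implementation="eager",
)
```

In recent transformers releases `attn_implementation` accepts "eager", "sdpa", or "flash_attention_2", so machines without flash-attn installed can still run the model.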
modeling_aria.py
CHANGED
@@ -133,6 +133,7 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
     def __init__(self, config: AriaConfig):
         super().__init__(config)
 
+        config.vision_config._attn_implementation = config._attn_implementation
         self.vision_tower = AriaVisionModel(config.vision_config)
         self.multi_modal_projector = build_mm_projector(config)
         self.vocab_size = config.text_config.vocab_size
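The added line propagates the top-level attention implementation into the vision sub-config before the vision tower is built, so both towers resolve to the same backend. A toy sketch of the pattern, using simplified stand-in objects rather than the actual Aria config classes:

```python
from types import SimpleNamespace

def build_model(config):
    # Mirror of the line added above: copy the parent's choice into the
    # sub-config so the vision tower sees the user's override too.
    config.vision_config._attn_implementation = config._attn_implementation
    return config

config = SimpleNamespace(
    _attn_implementation="eager",  # e.g. set via from_pretrained(...)
    vision_config=SimpleNamespace(_attn_implementation="flash_attention_2"),
)
build_model(config)
assert config.vision_config._attn_implementation == "eager"
```

Without this copy, an `attn_implementation="eager"` override would only reach the language model, while the vision tower kept whatever its own config carried.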
vision_encoder.py
CHANGED
@@ -38,7 +38,6 @@ class AriaVisionConfig(SiglipVisionConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._attn_implementation = "flash_attention_2"
 
 
 class IdentityOp(torch.nn.Module):
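Dropping the hard-coded assignment is the change that actually enables eager mode: `AriaVisionConfig` no longer forces flash attention, so the value flowing in from config.json (or a user override) decides which attention class the encoder builds. An illustrative dispatch in the spirit of transformers' per-backend attention classes for SigLIP-style encoders; the mapping and class names below are an assumption, not a quote of the Aria source:

```python
# Hypothetical backend-to-class mapping; the point is that
# config._attn_implementation is now the single source of truth.
ATTENTION_CLASSES = {
    "eager": "SiglipAttention",
    "sdpa": "SiglipSdpaAttention",
    "flash_attention_2": "SiglipFlashAttention2",
}

def pick_attention(config):
    impl = getattr(config, "_attn_implementation", None) or "eager"
    return ATTENTION_CLASSES[impl]
```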