Merge pull request #30 from LightricksResearch/fix-no-flash-attention
model: fix flash attention enabling - do not check device type at this point
xora/models/transformers/attention.py
CHANGED
@@ -179,15 +179,14 @@ class BasicTransformerBlock(nn.Module):
         self._chunk_size = None
         self._chunk_dim = 0

-    def set_use_tpu_flash_attention(self, device):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        ...
-        ...
-        ...
-        self.attn2.set_use_tpu_flash_attention(device)
+        self.use_tpu_flash_attention = True
+        self.attn1.set_use_tpu_flash_attention()
+        self.attn2.set_use_tpu_flash_attention()

     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward

@@ -508,12 +507,11 @@ class Attention(nn.Module):
         processor = AttnProcessor2_0()
         self.set_processor(processor)

-    def set_use_tpu_flash_attention(self, device):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
         """
-        ...
-        self.use_tpu_flash_attention = True
+        self.use_tpu_flash_attention = True

     def set_processor(self, processor: "AttnProcessor") -> None:
         r"""
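The change in attention.py is a plain recursive flag toggle: each module flips its own use_tpu_flash_attention boolean and calls the same setter on its children, with no device inspection along the way. As a rough illustration of that pattern only (toy classes below, not the repository's actual BasicTransformerBlock or Attention implementations), a minimal sketch looks like this:

import torch.nn as nn


class ToyAttention(nn.Module):
    # Stand-in for the leaf Attention module (illustration only).
    def __init__(self):
        super().__init__()
        self.use_tpu_flash_attention = False

    def set_use_tpu_flash_attention(self):
        # Leaf: just record that the TPU flash-attention kernel should be used.
        self.use_tpu_flash_attention = True


class ToyTransformerBlock(nn.Module):
    # Stand-in for BasicTransformerBlock, which owns attn1 and attn2.
    def __init__(self):
        super().__init__()
        self.use_tpu_flash_attention = False
        self.attn1 = ToyAttention()
        self.attn2 = ToyAttention()

    def set_use_tpu_flash_attention(self):
        # Set the flag locally, then propagate to the children;
        # no device-type check is performed at this point.
        self.use_tpu_flash_attention = True
        self.attn1.set_use_tpu_flash_attention()
        self.attn2.set_use_tpu_flash_attention()


block = ToyTransformerBlock()
block.set_use_tpu_flash_attention()
assert block.attn1.use_tpu_flash_attention and block.attn2.use_tpu_flash_attention

Because the setters only mutate Python attributes, they are safe to call on a module that has not yet been placed on its target device.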
xora/models/transformers/transformer3d.py
CHANGED
@@ -160,13 +160,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        logger.info("
-        ...
-        ...
-        ...
-        ...
-        for block in self.transformer_blocks:
-            block.set_use_tpu_flash_attention(self.device.type)
+        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
+        self.use_tpu_flash_attention = True
+        # push config down to the attention modules
+        for block in self.transformer_blocks:
+            block.set_use_tpu_flash_attention()

     def initialize(self, embedding_std: float, mode: Literal["xora", "legacy"]):
         def _basic_init(module):
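In transformer3d.py the top-level entry point follows suit: it logs, sets its own flag, and pushes the setting down to every transformer block, instead of passing self.device.type down and gating on it. Since the call only flips booleans, it can be issued before the model ever reaches the TPU. A minimal usage sketch, assuming the set_use_tpu_flash_attention() entry point from this diff and the standard torch_xla xm.xla_device() call; the helper name prepare_for_tpu and the commented-out setup are illustrative and not part of the repository:

import torch
import torch_xla.core.xla_model as xm  # standard torch_xla import; not part of this diff


def prepare_for_tpu(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: enable the TPU flash-attention flag, then move the
    # model to the XLA device. Because set_use_tpu_flash_attention() no longer
    # inspects the current device, the two steps are not coupled in order.
    model.set_use_tpu_flash_attention()
    return model.to(xm.xla_device())


# transformer = Transformer3DModel(...)          # construction not shown in this diff
# transformer = prepare_for_tpu(transformer)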