michael-guenther committed
Commit 290e593
1 Parent(s): 807ba34

support-cpu (#2)


- set use_flash_attn if not available (71ef01733c24797743bbc24b9c39661cb4a132e2)
- use getattr function (695207d827c7094bf9b0f7d7c048692b0633488e)

Files changed (2)
  1. mha.py +0 -2
  2. modeling_xlm_roberta.py +16 -3
mha.py CHANGED
@@ -10,8 +10,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat
 
-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
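After this change the only reference to flash_attn left in mha.py is the guarded import shown in the hunk above, so the module imports cleanly on machines where the package is absent. A minimal sketch of that pattern (the except branch is assumed here, since the hunk is truncated before it):

try:
    from flash_attn import flash_attn_kvpacked_func
except ImportError:
    # Assumed fallback: leave the symbol as None so callers can check for it.
    flash_attn_kvpacked_func = None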
modeling_xlm_roberta.py CHANGED
@@ -1,6 +1,5 @@
 # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
 # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
-
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
@@ -8,6 +7,7 @@
 
 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 
+import importlib.util
 import logging
 import re
 from collections import OrderedDict
@@ -65,8 +65,21 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+def get_use_flash_attn(config: XLMRobertaFlashConfig):
+    if not getattr(config, "use_flash_attn", False):
+        return False
+    if not torch.cuda.is_available():
+        return False
+    if importlib.util.find_spec("flash_attn") is None:
+        logger.warning(
+            'flash_attn is not installed. Using PyTorch native attention implementation.'
+        )
+        return False
+    return True
+
+
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = getattr(config, "use_flash_attn", False)
+    use_flash_attn = get_use_flash_attn(config)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
@@ -169,7 +182,7 @@ def _init_weights(module, initializer_range=0.02):
 class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
-        self.use_flash_attn = getattr(config, "use_flash_attn", False)
+        self.use_flash_attn = get_use_flash_attn(config)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
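Together, the two files now degrade gracefully when CUDA or flash-attn is unavailable. A standalone sketch of how the new helper behaves on a CPU-only machine (DummyConfig is a made-up stand-in for XLMRobertaFlashConfig, used only for illustration):

import importlib.util
import logging

import torch

logger = logging.getLogger(__name__)


class DummyConfig:
    # Stand-in config exposing only the flag the helper reads.
    use_flash_attn = True


def get_use_flash_attn(config):
    # Same logic as the helper added to modeling_xlm_roberta.py above.
    if not getattr(config, "use_flash_attn", False):
        return False
    if not torch.cuda.is_available():
        return False
    if importlib.util.find_spec("flash_attn") is None:
        logger.warning(
            "flash_attn is not installed. Using PyTorch native attention implementation."
        )
        return False
    return True


print(get_use_flash_attn(DummyConfig()))  # False on a CPU-only machine

Note that the warning is only emitted when the flash_attn package itself is missing; on a CUDA-less machine that has the package installed, the helper returns False silently.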