Use try-except for flash_attn

#5
Files changed (1)
  1. modeling_deepseek.py +3 -3
modeling_deepseek.py CHANGED
@@ -48,7 +48,6 @@ from transformers.pytorch_utils import (
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -58,10 +57,11 @@ from .configuration_deepseek import DeepseekV2Config
 import torch.distributed as dist
 import numpy as np
 
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
+except ImportError:
+    pass
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
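For context, the patched import guard follows the usual optional-dependency pattern: attempt the import and degrade gracefully when `flash_attn` is not installed, instead of relying on `is_flash_attn_2_available()`. The sketch below illustrates that pattern in isolation; the `HAS_FLASH_ATTN` flag and the `attention_forward` helper are illustrative names and are not part of `modeling_deepseek.py`.

```python
# Illustrative sketch of the optional-import pattern used in this patch.
# HAS_FLASH_ATTN and attention_forward are hypothetical names, not from the model file.
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    # flash-attn is optional: importing the module must not fail without it.
    HAS_FLASH_ATTN = False


def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (batch, seq_len, num_heads, head_dim)."""
    if HAS_FLASH_ATTN:
        # flash_attn_func consumes (batch, seq_len, num_heads, head_dim) directly.
        return flash_attn_func(q, k, v, causal=True)
    # Fallback: PyTorch's scaled_dot_product_attention expects
    # (batch, num_heads, seq_len, head_dim), so transpose around the call.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    return out.transpose(1, 2)
```

Keeping a fallback on the standard PyTorch path means the checkpoint can still be imported and run on machines where flash-attn cannot be installed (e.g. CPU-only environments).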