zR committed on
Commit
6c2e473
1 Parent(s): fe11ac1
Files changed (1) hide show
  1. modeling_chatglm.py +9 -6
modeling_chatglm.py CHANGED
@@ -21,17 +21,20 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
25
- is_flash_attn_2_available
26
  from transformers.generation.logits_process import LogitsProcessor
27
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
 
29
- from .configuration_chatglm import ChatGLMConfig
30
  from .visual import EVA2CLIPModel
 
31
 
32
- if is_flash_attn_2_available():
33
- from flash_attn import flash_attn_func, flash_attn_varlen_func
34
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
 
 
35
 
36
  # flags required to enable jit fusion kernels
37
 
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available
 
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
 
28
  from .visual import EVA2CLIPModel
29
+ from .configuration_chatglm import ChatGLMConfig
30
 
31
+ try:
32
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
33
+ if is_flash_attn_2_available():
34
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
35
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
36
+ except:
37
+ pass
38
 
39
  # flags required to enable jit fusion kernels
40