Try to subclass PretrainedModel
modeling_bert.py CHANGED: +1 -2
```diff
@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import PretrainedModel
+from transformers.modeling_utils import PretrainedModel
 from .configuration_bert import JinaBertConfig
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -39,7 +39,6 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
-from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
```