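"""Qwen3 building blocks with causal attention masking disabled.

The subclasses below reuse the stock transformers Qwen3 modules but set
``is_causal = False`` on every attention layer, so the decoder stack can be
used as a bidirectional text encoder (``Qwen3EncoderModel``).
"""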
import importlib.metadata

from packaging import version
from torch import nn
from transformers import Qwen3Config, Qwen3Model, Qwen3PreTrainedModel
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    Qwen3DecoderLayer,
    Qwen3MLP,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
)
from transformers.utils import logging
from transformers.utils.import_utils import _is_package_available

logger = logging.get_logger(__name__)
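

# This module mirrors modeling_qwen3.py from transformers >= 4.56.2; the helper
# below lets the encoder refuse to build against older, incompatible releases.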
def is_transformers_attn_greater_or_equal_4_56_2():
    if not _is_package_available("transformers"):
        return False

    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
        "4.56.2"
    )
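

# Identical to Qwen3Attention except that the causal flag is disabled, so the
# attention implementation no longer applies a left-to-right (causal) mask.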
class ModifiedQwen3Attention(Qwen3Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
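

# Rebuilds Qwen3DecoderLayer with the non-causal attention defined above.
# GradientCheckpointingLayer.__init__ is called directly (rather than the
# parent constructor) so the stock causal Qwen3Attention is never instantiated.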
class ModifiedQwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        GradientCheckpointingLayer.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = ModifiedQwen3Attention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.attention_type = config.layer_types[layer_idx]
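

# Qwen3Model rebuilt with ModifiedQwen3DecoderLayer so every layer runs
# non-causal attention; the rest of the constructor follows the upstream model.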
class Qwen3EncoderModel(Qwen3Model):
    _no_split_modules = ["ModifiedQwen3DecoderLayer"]

    def __init__(self, config: Qwen3Config):
        if not is_transformers_attn_greater_or_equal_4_56_2():
            raise ValueError(
"The current implementation of Qwen2EncoderModel follows modeling_qwen2.py of transformers version >= 4.56.2" |
|
|
) |
|
|
        Qwen3PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
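        # Decoder layers are replaced with the non-causal variant defined above.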
        self.layers = nn.ModuleList(
            [
                ModifiedQwen3DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        self.post_init()
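

# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint name below is an assumption; any Qwen3 checkpoint usable with
# transformers >= 4.56.2 should work the same way.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen3-0.6B"  # assumed checkpoint; replace as needed
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = Qwen3EncoderModel.from_pretrained(checkpoint)

    inputs = tokenizer(["An example sentence to encode."], return_tensors="pt")
    outputs = model(**inputs)
    # (batch_size, sequence_length, hidden_size) token representations from the
    # non-causal encoder stack.
    print(outputs.last_hidden_state.shape)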