# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

try:
    from einops import rearrange
except ImportError:
    rearrange = None
from torch import nn

SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7

from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
    HistoryType,
    make_context,
    decode_tokens,
    get_stop_words_ids,
    StopWordsLogitsProcessor,
)
from .visual import VisionTransformer

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"

QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]

_ERROR_BAD_CHAT_FORMAT = """\
We detect that you are probably using the pretrained model (rather than the chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Passing the argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""
apply_rotary_emb_func = None
rms_norm = None


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
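    # The returned mask is additive: 0.0 for visible positions (including the cached past)
    # and dtype-min for future positions, with shape [bsz, 1, tgt_len, tgt_len + past_key_values_length].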
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
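    # Positions with attention_mask == 1 end up as 0.0 and padded positions as dtype-min,
    # so the result can simply be added to the attention scores.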
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class QWenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.split_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.scale_attn_weights = True
        self.projection_size = config.kv_channels * config.num_attention_heads
        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
        self.c_proj = nn.Linear(
            config.hidden_size, self.projection_size, bias=not config.no_bias
        )
        self.is_fp32 = not (config.bf16 or config.fp16)
        self.bf16 = config.bf16
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.use_logn_attn = config.use_logn_attn
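        # log-n attention scaling: at inference time, queries at positions beyond the trained
        # seq_length are multiplied by log_{seq_length}(position), which helps keep the
        # attention distribution well-behaved when extrapolating to longer contexts.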
        logn_list = [
            math.log(i, self.seq_length) if i > self.seq_length else 1
            for i in range(1, 32768)
        ]
        self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)

    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [],
                value.size(-1) ** 0.5,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
        query_length, key_length = query.size(-2), key.size(-2)
        # causal_mask = self.bias[
        #     :, :, key_length - query_length : key_length, :key_length
        # ]
        # mask_value = torch.finfo(attn_weights.dtype).min
        # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
        #     attn_weights.device
        # )
        # attn_weights = torch.where(
        #     causal_mask, attn_weights.to(attn_weights.dtype), mask_value
        # )
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)
        return attn_output, attn_weights

    def _upcast_and_reordered_attn(
        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
    ):
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()
        attn_weights = torch.empty(
            bsz * num_heads,
            q_seq_len,
            k_seq_len,
            dtype=torch.float32,
            device=query.device,
        )
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
                -1, dk, k_seq_len
            )
            attn_weights = torch.baddbmm(
                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
            )
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = registered_causal_mask[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        if attn_weights.dtype != torch.float32:
            raise RuntimeError(
                "Error with upcasting, attn_weights does not have dtype torch.float32"
            )
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        tensor = tensor.contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
        registered_causal_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        mixed_x_layer = self.c_attn(hidden_states)
        query, key, value = mixed_x_layer.split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        if rotary_pos_emb is not None:
            # Slice the pos emb for current inference
            cur_len = query.shape[1]
            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
            rotary_pos_emb = (rotary_pos_emb,) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)
        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)
        if use_cache:
            present = (key, value)
        else:
            present = None
        if self.use_logn_attn and not self.training:
            if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
                self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
            seq_start = key.size(1) - query.size(1)
            seq_end = key.size(1)
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
            query = query * logn_tensor.expand_as(query)
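        # [bs, seq_len, n_heads, head_dim] -> [bs, n_heads, seq_len, head_dim]
        # so the attention scores can be computed with a single batched matmul.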
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)
        attn_output, attn_weight = self._attn(
            query, key, value, registered_causal_mask, attention_mask, head_mask
        )
        context_layer = self._merge_heads(
            attn_output, self.num_heads, self.head_dim
        )
        attn_output = self.c_proj(context_layer)
        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weight,)
        return outputs


class QWenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        self.w2 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        ff_dim_in = config.intermediate_size // 2
        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

    def forward(self, hidden_states):
        a1 = self.w1(hidden_states)
        a2 = self.w2(hidden_states)
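        # SwiGLU-style gating: the w1 branch is modulated elementwise by SiLU(w2(x)),
        # and the gated activation is projected back to hidden_size by c_proj.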
        intermediate_parallel = a1 * F.silu(a2)
        output = self.c_proj(intermediate_parallel)
        return output


class QWenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.bf16 = config.bf16
        self.ln_1 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.attn = QWenAttention(config)
        self.ln_2 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
        registered_causal_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
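        # Pre-norm residual layout: x = x + Attn(RMSNorm(x)), then x = x + MLP(RMSNorm(x)).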
        layernorm_output = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            layernorm_output,
            rotary_pos_emb,
            registered_causal_mask=registered_causal_mask,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
        residual = hidden_states
        layernorm_input = attn_output + residual
        layernorm_output = self.ln_2(layernorm_input)
        residual = layernorm_input
        mlp_output = self.mlp(layernorm_output)
        hidden_states = residual + mlp_output
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]
        return outputs


class QWenPreTrainedModel(PreTrainedModel):
    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                p.data.normal_(
                    mean=0.0,
                    std=(
                        self.config.initializer_range
                        / math.sqrt(2 * self.config.num_hidden_layers)
                    ),
                )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value


class QWenModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length
        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.emb_dropout_prob)
        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(
                config.kv_channels * config.rotary_pct
            )
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else config.kv_channels
        )
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
        self.use_flash_attn = config.use_flash_attn
        self.is_fp32 = not (config.bf16 or config.fp16)
        self.registered_causal_mask = None
        # if (
        #     self.use_flash_attn
        #     and flash_attn_unpadded_func is not None
        #     and not self.is_fp32
        # ):
        #     self.registered_causal_mask = None
        # else:
        #     max_positions = config.max_position_embeddings
        #     self.register_buffer(
        #         "registered_causal_mask",
        #         torch.tril(
        #             torch.ones((max_positions, max_positions), dtype=torch.bool)
        #         ).view(1, 1, max_positions, max_positions),
        #         persistent=False,
        #     )
        self.h = nn.ModuleList(
            [
                QWenBlock(
                    config
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )
        self.visual = VisionTransformer(**config.visual)
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
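            # Both masks are additive, so summing them keeps a position masked whenever either
            # the causal constraint or the padding mask excludes it.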
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        images=None,
    ):
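        # Visual inputs: spans delimited by image_start_id .. image_start_id + 1 in input_ids
        # carry a UTF-8 encoded image path or URL (right-padded with image_start_id + 2).  The
        # paths are decoded and run through the ViT, and the resulting features later overwrite
        # the corresponding placeholder positions of hidden_states.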
        if images is None:
            if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
                bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
                eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
                assert (bos_pos[0] == eos_pos[0]).all()
                img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
                images = []
                for i, a, b in img_pos:
                    image = input_ids[i][a + 1 : b - 1].tolist()
                    image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
                    images.append(bytes(image).decode('utf-8'))
                images = self.visual.encode(images)
                assert images.shape[0] == len(images)
            else:
                images = None
        else:
            bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
            eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
            assert (bos_pos[0] == eos_pos[0]).all()
            img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_length
        )
        hidden_states = inputs_embeds
        kv_seq_len = hidden_states.size()[1]
        if past_key_values[0] is not None:
            # past key values[0][0] shape: bs * seq_len * head_num * dim
            kv_seq_len += past_key_values[0][0].shape[1]
        if (
            self.use_dynamic_ntk
            and kv_seq_len == hidden_states.size()[1]
            and not self.training
        ):
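            # Dynamic NTK scaling: once the context exceeds the trained seq_length, the rotary
            # base is enlarged via ntk_alpha = 2 ** ceil(log2(kv_seq_len / seq_length) + 1) - 1,
            # which interpolates positions rather than extrapolating them.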
            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
            ntk_alpha = 2 ** math.ceil(context_value) - 1
            ntk_alpha = max(ntk_alpha, 1)
        else:
            ntk_alpha = self.rotary_emb._ntk_alpha_cached
        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
        for idx in range(len(rotary_pos_emb)):
            rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
        hidden_states = self.drop(hidden_states)
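        # Splice the visual features into the token embeddings: for image idx, positions
        # a + 1 .. b - 1 of sample i (its placeholder span) are replaced by the ViT output.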
        if images is not None:
            for idx, (i, a, b) in enumerate(img_pos):
                hidden_states[i][a + 1 : b] = images[idx]
        output_shape = input_shape + (hidden_states.size(-1),)
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb,
                    self.registered_causal_mask,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb=rotary_pos_emb,
                    registered_causal_mask=self.registered_causal_mask,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v for v in [hidden_states, presents, all_hidden_states] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class QWenLMHeadModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
        if autoset_precision:
            if SUPPORT_BF16:
                logger.warn(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warn(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True
        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warn("Your device does NOT seem to support bf16; you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warn("Your device does NOT support faster inference with fp16; please switch to fp32, which is likely to be faster.")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warn("Your device supports faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warn("Your device supports faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
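        # During incremental decoding only the newest token (and its position id) is fed to
        # the model; the earlier context is supplied through past_key_values.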
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        append_history: bool = True,
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []
        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )
        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )
        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=generation_config.chat_format,
            verbose=False,
            errors='replace'
        )
        if append_history:
            history.append((query, response))
        return response, history
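
    # Illustrative usage of chat() (a sketch only; it assumes the tokenizer and model were
    # loaded from the same checkpoint with trust_remote_code=True):
    #   response, history = model.chat(tokenizer, "Describe this picture.", history=None)
    #   response, history = model.chat(tokenizer, "Any follow-up questions?", history=history)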

    def chat_pretrain(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        append_history: bool = False,
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []
        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )
        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )
        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=generation_config.chat_format,
            verbose=False,
            errors='replace'
        )
        if append_history:
            history.append((query, response))
        return response, history

    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []
        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )
        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))
        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)
        input_ids = torch.tensor([context_tokens]).to(self.device)
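        # Streaming relies on the transformers_stream_generator package: its generate /
        # sample_stream methods are patched onto the class, and the generator below yields
        # the decoded text produced so far after every new token.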
        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                    input_ids,
                    return_dict_in_generate=False,
                    generation_config=stream_config,
                    logits_processor=logits_processor,
                    seed=-1,
                    **kwargs):
                outputs.append(token.item())
                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

        return stream_generator()

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        generation_config = generation_config if generation_config is not None else self.generation_config
        # Process stop_words_ids.
        stop_words_ids = kwargs.pop("stop_words_ids", None)
        if stop_words_ids is None and generation_config is not None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
        if stop_words_ids is None:
            stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)
        return super().generate(
            inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            **kwargs,
        )


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")
        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        seqlen = max_seq_len + offset
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
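            # NTK-aware scaling: enlarging the base by ntk_alpha ** (dim / (dim - 2)) stretches
            # the rotary wavelengths so that longer contexts fit without retraining.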
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            from einops import rearrange
            emb = rearrange(emb, "n d -> 1 n 1 d")
            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]


def _rotate_half(x):
    from einops import rearrange
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    cos, sin = freqs
    if apply_rotary_emb_func is not None and t.is_cuda:
        t_ = t.float()
        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
        return output
    else:
        rot_dim = freqs[0].shape[-1]
        cos, sin = freqs
        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
        t_ = t_.float()
        t_pass_ = t_pass_.float()
        t_ = (t_ * cos) + (_rotate_half(t_) * sin)
        return torch.cat((t_, t_pass_), dim=-1).type_as(t)


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
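        # RMSNorm rescales x by the reciprocal root-mean-square over the last dimension:
        # y = weight * x / sqrt(mean(x ** 2) + eps); unlike LayerNorm there is no mean-centering.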

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        if rms_norm is not None and x.is_cuda:
            return rms_norm(x, self.weight, self.eps)
        else:
            output = self._norm(x.float()).type_as(x)
            return output * self.weight