|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch RWKV6Qwen2 model.""" |
|
|
|
import math |
|
import inspect |
|
from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.cache_utils import Cache, StaticCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_code_sample_docstrings, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from .configuration_rwkv6qwen2 import RWKV6Qwen2Config |
|
|
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv |
|
|
|
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Names consumed by the docstring decorators; `_CONFIG_FOR_DOC` is passed to
# `replace_return_docstrings` on `RWKV6Qwen2ForCausalLM.forward`.
# NOTE(review): `_CHECKPOINT_FOR_DOC` is not referenced in the visible part of the file.
_CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
_CONFIG_FOR_DOC = "RWKV6Qwen2Config"
|
|
|
class RWKV6State(Cache):
    """Per-layer recurrent state used by `RWKV6Attention` in place of a key/value cache.

    For every layer two tensors are stored: the kv state handed to / returned by the
    attention kernel, and the shift state (the last hidden state the layer received).
    Stored tensors are overwritten in place on each `update`; nothing grows with the
    sequence length, so length-related cache queries have trivial answers.
    """

    def __init__(self) -> None:
        # NOTE(review): `Cache.__init__` is not invoked (unchanged from the original);
        # confirm this is still correct for the installed transformers version.
        self._seen_tokens = 0  # running count of tokens passed through layer 0
        self.layer_kv_states: List[torch.Tensor] = []
        self.layer_shift_states: List[torch.Tensor] = []

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the `(kv_state, shift_state)` pair stored for `layer_idx`.

        Raises:
            KeyError: if `layer_idx` is not smaller than the number of stored layers.
        """
        if layer_idx < len(self):
            return (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Yield the `(kv_state, shift_state)` pair of every stored layer, in layer order."""
        # Both lists are always appended to together in `update`, so they have equal length.
        yield from zip(self.layer_kv_states, self.layer_shift_states)

    def __len__(self):
        """Return the number of layers that currently have a stored state."""
        return len(self.layer_kv_states)

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Return `new_seq_length` unchanged; `layer_idx` is ignored."""
        return new_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Beam-search reordering is not supported; always raises `NotImplementedError`."""
        raise NotImplementedError('Cannot reorder Linear Attention state')

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the number of tokens tallied so far; `layer_idx` is ignored."""
        return self._seen_tokens

    def get_max_cache_shape(self) -> Optional[int]:
        """Return `None`: this state object has no maximum length."""
        return None

    def get_max_length(self) -> Optional[int]:
        """Return `None`: this state object has no maximum length."""
        return None

    def crop(self, max_length: int):
        """No-op; the stored state is left untouched and `None` is returned."""
        return

    # `torch.no_grad()` (with parentheses) works as a decorator on every PyTorch
    # release, whereas the bare `@torch.no_grad` form requires a recent one.
    @torch.no_grad()
    def update(
        self,
        kv_state: torch.Tensor,
        shift_state: torch.Tensor,
        token_count: int,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Store the newest states of layer `layer_idx` and return the stored tensors.

        Args:
            kv_state: new kv state for the layer; copied into the stored tensor.
            shift_state: new shift state for the layer; copied into the stored tensor.
            token_count: number of tokens processed in this call; added to the token
                tally only when `layer_idx == 0` so each step is counted once.
            layer_idx: index of the layer being updated.
            cache_kwargs: accepted for interface compatibility; unused.

        Returns:
            The stored `(kv_state, shift_state)` tensors for `layer_idx`.
        """
        if layer_idx == 0:
            self._seen_tokens += token_count

        # Allocate zero-filled placeholders for any layer up to `layer_idx` that has
        # no state yet, then overwrite this layer's slot in place.
        for _ in range(len(self.layer_kv_states), layer_idx + 1):
            self.layer_kv_states.append(torch.zeros_like(kv_state).requires_grad_(False))
            self.layer_shift_states.append(torch.zeros_like(shift_state).requires_grad_(False))
        self.layer_kv_states[layer_idx].copy_(kv_state)
        self.layer_shift_states[layer_idx].copy_(shift_state)

        return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fla.ops.gla.chunk import chunk_gla |
|
from fla.ops.gla.fused_recurrent import fused_recurrent_gla |
|
|
|
class RWKV6Attention(nn.Module):
    """Gated linear attention block with token-shift mixing and data-dependent decay.

    Query/key/value/output projections follow the sizes given by the config
    (`num_attention_heads`, `num_key_value_heads`, optional `head_dim`). On top of
    these the module owns token-shift mixing parameters (`time_maa_*`), a low-rank
    data-dependent decay (`time_decay*`) and a sigmoid output gate (`gate`).
    The recurrence itself is delegated to `fused_recurrent_gla`.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        """Build projections and initialise the mixing / decay parameters.

        Args:
            config: model configuration; the attributes read here are `hidden_size`,
                `num_attention_heads`, `num_key_value_heads`, `attention_dropout`,
                `attention_bias`, `num_hidden_layers` and, optionally, `head_dim`
                and `attention_output_bias`.
            layer_idx: position of this layer in the stack. It selects the layer's
                slot in the recurrent state and shapes the decay initialisation.
                `None` is tolerated (with a warning) and treated as layer 0 for
                initialisation purposes only.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=getattr(config, 'attention_output_bias', config.attention_bias),
        )

        # The gate weight starts at zero, so the gate initially outputs sigmoid(0) = 0.5.
        self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        nn.init.zeros_(self.gate.weight)

        n_layer = self.config.num_hidden_layers
        n_embd = self.hidden_size
        dim_att = self.num_heads * self.head_dim
        # A missing layer index was already warned about; fall back to layer 0 here
        # instead of failing with a TypeError in the arithmetic below.
        layer_id = 0 if self.layer_idx is None else self.layer_idx

        with torch.no_grad():
            # Denominators are clamped so a single-layer model (or dim_att == 1)
            # does not divide by zero.
            ratio_0_to_1 = layer_id / max(n_layer - 1, 1)
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)

            # Token-shift mixing coefficients. `ddd` is all zeros, so with a positive
            # exponent `time_maa_x` starts at 1 and the remaining coefficients at 0.
            ddd = torch.zeros(1, 1, n_embd)
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
            self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
            self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
            self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
            self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))

            # Low-rank projection producing the five data-dependent mixing offsets
            # (r, k, v, w, g); w1 starts at zero so the offsets start at zero.
            D_MIX_LORA = 32 if n_embd < 4096 else 64
            self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, n_embd).uniform_(-0.01, 0.01))
            self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, D_MIX_LORA * self.time_maa_w2.size(0)))

            # Per-channel base decay, ramping from -6 to -1 across channels with a
            # layer-dependent curvature. Values are computed as Python floats and
            # materialised in a single tensor construction.
            exponent = 0.7 + 1.3 * ratio_0_to_1
            last_channel = max(dim_att - 1, 1)
            decay_speed = torch.tensor([-6 + 5 * (n / last_channel) ** exponent for n in range(dim_att)])
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, dim_att))

            # Low-rank data-dependent correction added to the base decay.
            D_DECAY_LORA = 64 if n_embd < 4096 else 128
            self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, dim_att).uniform_(-0.01, 0.01))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional["RWKV6State"] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        """Run the linear-attention recurrence over `hidden_states`.

        Args:
            hidden_states: input of shape `(batch, seq_len, hidden_size)`.
            attention_mask, position_ids, output_attentions, cache_position,
                position_embeddings: accepted for interface compatibility; unused.
            past_key_value: recurrent state object; read when it already holds an
                entry for this layer and `use_cache` is set.
            use_cache: whether to read from / write to `past_key_value`.

        Returns:
            `(attn_output, attn_weights, past_key_value)` where `attn_weights` is an
            empty placeholder tensor and `past_key_value` is the (possibly updated)
            state object that was passed in.
        """
        # The last position of this call becomes the shift state of the next call.
        output_shift_state = hidden_states[:, -1:].detach().clone()

        bsz, q_len, hidden_dim = hidden_states.size()
        x = hidden_states

        # Token shift: `xprev[:, t]` is `x[:, t - 1]`; the first slot comes from the
        # stored shift state when one exists, otherwise it is zero padding.
        if use_cache and past_key_value is not None and len(past_key_value) > self.layer_idx:
            input_kv_state, input_shift_state = past_key_value[self.layer_idx]
            xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
        else:
            input_kv_state = None
            xprev = F.pad(x, (0, 0, 1, -1))

        dxprev = xprev - x

        # Data-dependent mixing offsets for r, k, v, w, g via the low-rank projection.
        n_mix = self.time_maa_w2.size(0)
        xxx = x + dxprev * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz * q_len, n_mix, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(n_mix, bsz, q_len, hidden_dim)

        mr, mk, mv, mw, mg = xxx.unbind(dim=0)
        xr = x + dxprev * (self.time_maa_r + mr)
        xk = x + dxprev * (self.time_maa_k + mk)
        xv = x + dxprev * (self.time_maa_v + mv)
        xw = x + dxprev * (self.time_maa_w + mw)
        xg = x + dxprev * (self.time_maa_g + mg)

        query_states = self.q_proj(xr)
        key_states = self.k_proj(xk)
        value_states = self.v_proj(xv)
        decay_states = (
            self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        ).to(query_states.dtype)
        gate_states = torch.sigmoid(self.gate(xg))

        # Split into heads: (batch, heads, seq_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Expand key/value heads so they match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Log-decay is -exp(w), computed in float32; keys are scaled by (1 - decay).
        decay_states_log = -decay_states.float().exp()
        key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)

        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        # If the inputs ended up in float32, cast q/k/v back to the working dtype
        # (autocast dtype, pre-quantization dtype, or the projection weight dtype).
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Attention weights are never materialised; return an empty placeholder.
        attn_weights = torch.empty(0, device=x.device)

        scale = query_states.shape[-1] ** -0.5
        # The final state is only requested (and stored) outside of training.
        output_final_state = not self.training and use_cache and past_key_value is not None

        attn_output, output_kv_state = fused_recurrent_gla(
            query_states,
            key_states,
            value_states,
            decay_states_log,
            None,
            scale,
            input_kv_state,
            output_final_state,
        )

        if output_final_state:
            past_key_value.update(output_kv_state, output_shift_state, q_len, self.layer_idx)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output * gate_states)

        return attn_output, attn_weights, past_key_value
|
|
|
class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer whose self-attention module is an `RWKV6Attention`."""

    def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
        # Initialise as a plain `nn.Module`: `Qwen2DecoderLayer.__init__` is bypassed
        # and the sub-modules are assembled here instead.
        nn.Module.__init__(self)

        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width
        self.self_attn = RWKV6Attention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(width, eps=norm_eps)
|
|
|
RWKV6QWEN2_START_DOCSTRING = r""" |
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`RWKV6Qwen2Config`]): |
|
Model configuration class with all the parameters of the model. Initializing with a config file does not |
|
load the weights associated with the model, only the configuration. Check out the |
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    RWKV6QWEN2_START_DOCSTRING,
)
class RWKV6Qwen2PreTrainedModel(PreTrainedModel):
    # Class-level switches read by the `PreTrainedModel` machinery.
    config_class = RWKV6Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RWKV6Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        # Linear and embedding weights are drawn from N(0, initializer_range);
        # linear biases and the embedding padding row are zeroed. Any other module
        # type is left untouched.
        std = self.config.initializer_range
        linear = isinstance(module, nn.Linear)
        if not (linear or isinstance(module, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
|
|
|
|
RWKV6QWEN2_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
|
`past_key_values`). |
|
|
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] |
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more |
|
information on the default strategy. |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.n_positions - 1]`. |
|
|
|
[What are position IDs?](../glossary#position-ids) |
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): |
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` |
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. |
|
|
|
Two formats are allowed: |
|
- a [`~cache_utils.Cache`] instance, see our |
|
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); |
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of |
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy |
|
cache format. |
|
|
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the |
|
legacy cache format will be returned. |
|
|
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't |
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` |
|
of shape `(batch_size, sequence_length)`. |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
|
model's internal embedding lookup matrix. |
|
use_cache (`bool`, *optional*): |
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
|
`past_key_values`). |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
|
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`, |
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer |
|
the complete sequence length. |
|
""" |
|
|
|
@add_start_docstrings(
    "The bare RWKV6Qwen2 Model outputting raw hidden-states without any specific head on top.",
    RWKV6QWEN2_START_DOCSTRING,
)
class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
    """
    Decoder stack of *config.num_hidden_layers* layers, each a [`RWKV6Qwen2DecoderLayer`], preceded by a token
    embedding and followed by a final RMS norm.

    Args:
        config: RWKV6Qwen2Config
    """

    def __init__(self, config: RWKV6Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [RWKV6Qwen2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Flags left as None fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of the two input forms has to be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # Anything that is not already an RWKV6State is replaced by a fresh, empty one.
        if use_cache and not isinstance(past_key_values, RWKV6State):
            past_key_values = RWKV6State()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # No attention mask or position embeddings are built; the layers receive None for both.
        causal_mask = None
        position_embeddings = None

        hidden_states = inputs_embeds

        collected_hidden_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None
        latest_cache = None
        # Index of the cache entry in a layer's output tuple.
        cache_slot = 2 if output_attentions else 1

        for decoder_layer in self.layers:
            if output_hidden_states:
                collected_hidden_states += (hidden_states,)

            if checkpointing:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                latest_cache = layer_outputs[cache_slot]

            if output_attentions:
                collected_attentions += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Append the normalised output of the final layer as well.
        if output_hidden_states:
            collected_hidden_states += (hidden_states,)

        next_cache = latest_cache if use_cache else None

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=collected_hidden_states,
                attentions=collected_attentions,
            )

        candidates = (hidden_states, next_cache, collected_hidden_states, collected_attentions)
        return tuple(item for item in candidates if item is not None)
|
|
|
class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
    # Parameter names that may be tied to the input embeddings.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RWKV6Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RWKV6Qwen2ForCausalLM

        >>> model = RWKV6Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Flags left as None fall back to the config (use_cache is resolved by the base model).
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Project only the trailing `num_logits_to_keep` positions; with the default
        # of 0 the slice `-0:` keeps every position.
        last_hidden = outputs[0]
        logits = self.lm_head(last_hidden[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step: slice the inputs to the tokens not yet processed,
        derive `position_ids` from the attention mask when they are missing, and forward the remaining kwargs.

        See the forward pass in the model documentation for expected arguments.
        """
        model_inputs = {}

        # 1. Cache position: forwarded as-is when cache classes are supported, otherwise
        #    derived from the legacy cache length.
        if self._supports_cache_class:
            model_inputs["cache_position"] = cache_position
        elif cache_position is None:
            past_length = 0
            if past_key_values is not None:
                past_length = past_key_values[0][0].shape[2]
            cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        # 2. With a cache present, keep only the tokens that still need processing.
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            pending = cache_position.shape[0]
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
                input_ids = input_ids[:, -pending:]
            elif input_ids.shape[1] != pending:
                input_ids = input_ids[:, cache_position]

        # 3. Choose between `inputs_embeds` (first step only) and token ids.
        encoder_decoder = self.config.is_encoder_decoder
        input_ids_key = "decoder_input_ids" if encoder_decoder else "input_ids"

        if encoder_decoder:
            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
        elif inputs_embeds is not None and cache_position[0] == 0:
            model_inputs[input_ids_key] = None
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
            model_inputs["inputs_embeds"] = None

        # 4. Build `position_ids` from the attention mask when the caller gave none
        #    and `forward` accepts them.
        if (
            attention_mask is not None
            and kwargs.get("position_ids") is None
            and "position_ids" in inspect.signature(self.forward).parameters
        ):
            derived_positions = attention_mask.long().cumsum(-1) - 1
            derived_positions.masked_fill_(attention_mask == 0, 1)
            kwargs["position_ids"] = derived_positions

        # 5. Trim per-token inputs to the length of the (possibly sliced) input ids.
        for name in ("position_ids", "token_type_ids"):
            per_token = kwargs.get(name)
            if per_token is None:
                continue
            if past_key_values:
                per_token = per_token[:, -input_ids.shape[1] :]
                per_token = per_token.clone(memory_format=torch.contiguous_format)
            model_inputs[name] = per_token

        # 6. Static caches with a 2D mask get a 4D causal mask when a builder is available.
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            embeds = model_inputs["inputs_embeds"]
            if embeds is not None:
                batch_size, sequence_length, _ = embeds.shape
                device = embeds.device
            else:
                ids = model_inputs[input_ids_key]
                batch_size, sequence_length = ids.shape
                device = ids.device

            # Look for the mask builder on the base model if there is one, else on self.
            base_model = getattr(self, self.base_model_prefix, None)
            mask_owner = self if base_model is None else base_model
            causal_mask_creation_function = getattr(
                mask_owner, "_prepare_4d_causal_attention_mask_with_cache_position", None
            )

            if causal_mask_creation_function is None:
                logger.warning_once(
                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
                    "writing code, see Llama for an example implementation. If you're a user, please report this "
                    "issue on GitHub."
                )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    device=device,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )

        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask

        # 7. Forward every remaining kwarg without overriding what was set above.
        for key, value in kwargs.items():
            model_inputs.setdefault(key, value)

        # 8. Labels are never passed to the forward call during generation.
        model_inputs.pop("labels", None)
        return model_inputs
|
|
|
@add_start_docstrings(
    """
    The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).

    [`RWKV6Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    RWKV6QWEN2_START_DOCSTRING,
)
class RWKV6Qwen2ForSequenceClassification(RWKV6Qwen2PreTrainedModel):
    def __init__(self, config):
        """Build the backbone and a bias-free linear head mapping `hidden_size` to `num_labels`."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = RWKV6Qwen2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Score every position; a single position per row is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token there is no way to tell rows of different length apart.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Pick the rightmost token that differs from `pad_token_id`. Unlike "first pad token - 1",
            # this is correct for both left- and right-padded rows. Positions are multiplied by a 0/1
            # mask so that padded positions can never win the argmax.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # Padding cannot be detected from embeddings: fall back to the last position of each row.
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from the head size and the label dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            # Tuple output: optional loss, pooled logits, then everything the backbone returned after its first item.
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
|
@add_start_docstrings(
    """
    The RWKV6Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    RWKV6QWEN2_START_DOCSTRING,
)
class RWKV6Qwen2ForTokenClassification(RWKV6Qwen2PreTrainedModel):
    def __init__(self, config):
        """Build the backbone, a dropout layer and a linear head mapping `hidden_size` to `num_labels`."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = RWKV6Qwen2Model(config)

        # Dropout probability precedence: `classifier_dropout`, then `hidden_dropout`, then 0.1.
        dropout_prob = getattr(config, "classifier_dropout", None)
        if dropout_prob is None:
            dropout_prob = getattr(config, "hidden_dropout", None)
        if dropout_prob is None:
            dropout_prob = 0.1
        self.dropout = nn.Dropout(dropout_prob)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Labels passed, together with the per-token logits and the model config, to `self.loss_function` to
            compute the token classification loss. Presumably of shape `(batch_size, sequence_length)` with indices
            in `[0, ..., config.num_labels - 1]` (to be confirmed against the configured loss function).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Dropout is applied to the backbone's first output before every position is scored.
        token_states = self.dropout(backbone_outputs[0])
        logits = self.score(token_states)

        loss = self.loss_function(logits, labels, self.config) if labels is not None else None

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # Tuple output: optional loss, logits, then the backbone's items from index 2 onwards.
        output = (logits,) + backbone_outputs[2:]
        return output if loss is None else (loss,) + output
|
|
|
|
|
@add_start_docstrings(
    """
    The RWKV6Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    RWKV6QWEN2_START_DOCSTRING,
)
class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the backbone and a linear head with two outputs per position (span start / span end)."""
        super().__init__(config)
        self.model = RWKV6Qwen2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The head emits two values per position; split them into separate start / end logits
        # and drop the trailing singleton dimension of each.
        span_logits = self.qa_outputs(backbone_outputs[0])
        start_logits, end_logits = (
            chunk.squeeze(-1).contiguous() for chunk in span_logits.split(1, dim=-1)
        )

        # A loss is only computed when both span boundaries are supplied.
        loss = None
        if not (start_positions is None or end_positions is None):
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # Tuple output: optional loss, start logits, end logits, then the backbone's items from index 2 onwards.
        output = (start_logits, end_logits) + backbone_outputs[2:]
        return output if loss is None else (loss,) + output
|
|