# geb-1.3b / modeling_geb.py
# Uploaded by luxq: "upload config, tokenizer and modeling file" (commit 93e390f, verified)
"""PyTorch GEB model."""
import math
import copy
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, List
import importlib.util
from torch.nn.utils import skip_init
import torch.nn.functional as F
import torch
import torch.utils.checkpoint
from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, LayerNorm, CrossEntropyLoss, MSELoss
from copy import deepcopy
from deepspeed.accelerator import get_accelerator
try:
from einops import rearrange
except ImportError:
rearrange = None
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_geblm import GEBConfig
try:
# FlashAttention-2
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
# DeepSpeed op-builder class for flash attention on the current accelerator.
# Evaluated at import time, so importing this module requires deepspeed.
FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
# Loaded op (or None). Filled in by GEBAttention.__init__ when config.use_flash_attn
# is set; FlashSelfAttention uses it on non-CUDA accelerators.
flash_attn_builder = None
logger = logging.get_logger(__name__)
# Presumably for the transformers docstring helpers; not referenced in the code visible here.
_CHECKPOINT_FOR_DOC = "geb"
_CONFIG_FOR_DOC = "GEBConfig"
def _config_to_kwargs(args):
common_kwargs = {
"dtype": args.torch_dtype,
}
return common_kwargs
def default_init(cls, *args, **kwargs):
    """Construct ``cls`` normally; the eager counterpart of ``torch.nn.utils.skip_init``."""
    instance = cls(*args, **kwargs)
    return instance
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Replace scores containing NaN or Inf with a distribution forcing token id 5.

    If any entry of ``scores`` is NaN or infinite, ``scores`` is zeroed in place and
    index 5 of its last dimension is set to ``5e4``. Otherwise it is returned unchanged.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        found_nan = torch.isnan(scores).any()
        if not found_nan and not torch.isinf(scores).any():
            return scores
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split ``tensor`` into ``num_partitions`` chunks along its last dimension.

    Args:
        tensor: tensor to split.
        num_partitions: number of chunks. Each chunk is
            ``tensor.size(-1) // num_partitions`` wide; ``torch.split`` semantics
            apply when the size is not evenly divisible.
        contiguous_split_chunks: if True, make every chunk contiguous in memory.

    Returns:
        A tuple of tensors, as produced by ``torch.split``.
    """
    split_dim = tensor.dim() - 1
    chunk_size = tensor.size(split_dim) // num_partitions
    # torch.split returns views, which are generally not contiguous.
    chunks = torch.split(tensor, chunk_size, dim=split_dim)
    if not contiguous_split_chunks:
        return chunks
    return tuple(piece.contiguous() for piece in chunks)
class PrefixEncoder(torch.nn.Module):
    """Encode prefix token ids into flattened per-layer key/value vectors.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length,
                   num_layers * kv_channels * num_key_value_groups * 2)
    With ``config.prefix_projection`` set, the embedding is passed through a two-layer
    tanh MLP (kv_size -> hidden_size -> kv_size) before being returned.
    """

    def __init__(self, config: GEBConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # Width of one prefix position: keys and values (x2) for every layer.
        kv_size = config.num_layers * config.kv_channels * self.num_key_value_groups * 2
        self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
        if self.prefix_projection:
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size),
            )

    def forward(self, prefix: torch.Tensor):
        embedded = self.embedding(prefix)
        return self.trans(embedded) if self.prefix_projection else embedded
# class RotaryEmbedding(nn.Module):
# def __init__(self, dim, original_impl=False, device=None, dtype=None):
# super().__init__()
# inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
# self.register_buffer("inv_freq", inv_freq)
# self.dim = dim
# self.original_impl = original_impl
# def forward_impl(
# self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
# ):
# """Enhanced Transformer with Rotary Position Embedding.
# Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
# transformers/rope/__init__.py. MIT License:
# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
# """
# # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
# theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
# # Create position indexes `[0, 1, ..., seq_len - 1]`
# seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
# # Calculate the product of position index and $\theta_i$
# idx_theta = torch.outer(seq_idx, theta).float()
# cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
# # this is to mimic the behaviour of complex32, else we will get different results
# if dtype in (torch.float16, torch.bfloat16, torch.int8):
# cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
# return cache
# def forward(self, max_seq_len, offset=0):
# return self.forward_impl(
# max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
# )
# @torch.jit.script
# def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# # x: [sq, b, np, hn]
# sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
# rot_dim = rope_cache.shape[-2] * 2
# x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# # truncate to support variable sizes
# rope_cache = rope_cache[:sq]
# xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
# rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
# x_out2 = torch.stack(
# [
# xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
# xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
# ],
# -1,
# )
# x_out2 = x_out2.flatten(3)
# return torch.cat((x_out2, x_pass), dim=-1)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angle table.

    ``forward(max_seq_len, offset)`` returns the angles ``position * inv_freq`` for
    positions ``offset .. offset + max_seq_len - 1``. The angles are duplicated along
    the last axis and shaped ``[max_seq_len, 1, 1, dim]`` so the table broadcasts
    over the two middle axes.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))
        if importlib.util.find_spec('einops') is None:
            raise RuntimeError("einops is required for Rotary Embedding")

    def forward(self, max_seq_len, offset=0):
        positions = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
        # Outer product: angles[i, j] = positions[i] * inv_freq[j]
        angles = einsum('i , j -> i j', positions.type_as(self.inv_freq), self.inv_freq)
        # Repeat the half-width table so it spans the full rotary dimension.
        table = torch.cat((angles, angles), dim=-1)
        # [n, d] -> [n, 1, 1, d]
        return table[:, None, None, :]
def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
"""
from einops import rearrange
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary position embedding to ``t``.

    ``t`` is expected as ``[seq_length, ..., dim]`` and ``freqs`` (the angle table) as
    ``[seq_length, ..., rot_dim]``. Only the first ``rot_dim`` channels of ``t`` are
    rotated; remaining channels pass through unchanged and are re-appended.
    See https://kexue.fm/archives/8265 for the derivation.
    """
    rot_dim = freqs.shape[-1]
    rotated, passthrough = t[..., :rot_dim], t[..., rot_dim:]
    cos = freqs.cos().to(rotated.dtype)
    sin = freqs.sin().to(rotated.dtype)
    # x * cos(theta) + rotate_half(x) * sin(theta)
    rotated = rotated * cos + _rotate_half(rotated) * sin
    return torch.cat((rotated, passthrough), dim=-1)
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalisation (no mean-centering, no bias).

    ``y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The mean of squares is
    computed in float32 and the result is cast back to the input dtype.

    Args:
        normalized_shape: shape of ``weight`` (the size of the normalised last dim).
        eps: added to the mean of squares for numerical stability.
        device, dtype: forwarded to the ``weight`` parameter.
        **kwargs: accepted and ignored, for call-site compatibility.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Start from the identity scale so the module is well-defined even when no
        # checkpoint is loaded: _init_weights in this file does nothing, and
        # torch.empty would leave uninitialised memory here.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)
class MLP(torch.nn.Module):
    """SwiGLU feed-forward block.

    ``dense_h_to_4h`` projects ``hidden_size`` to ``2 * ffn_hidden_size``. The result
    is split into two halves ``(a, b)`` combined as ``silu(a) * b``, and
    ``dense_4h_to_h`` projects ``ffn_hidden_size`` back to ``hidden_size``.
    Both projections have a bias only if ``config.add_bias_linear`` is set.
    """

    def __init__(self, config: GEBConfig, device=None):
        super(MLP, self).__init__()
        self.add_bias = config.add_bias_linear
        # SwiGLU consumes two projections (gate and value), so the input projection
        # is twice ffn_hidden_size wide. See https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )
        # A staticmethod, rather than a closure created inside __init__, keeps the
        # module picklable (e.g. torch.save(model) on the whole module).
        self.activation_func = self._swiglu
        # Project back to hidden_size.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    @staticmethod
    def _swiglu(x):
        """Return ``silu(x1) * x2``, where ``x1, x2`` are the two halves of x's last dim."""
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

    def forward(self, hidden_states):
        """Apply the block; the last dimension goes h -> 2 * ffn -> ffn -> h."""
        # [..., 2 * ffn_hidden_size]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        # [..., ffn_hidden_size]
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [..., hidden_size]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
class CoreAttention(torch.nn.Module):
    """Reference (non-flash) scaled dot-product attention used by GEBAttention.

    ``forward`` takes query/key/value laid out as ``[seq, batch, heads, head_dim]``
    (as fixed by the view/permute calls below) and returns the context as
    ``[sq, b, heads * head_dim]``. Positions where ``attention_mask`` is True are
    filled with ``-inf`` before the softmax.
    """
    def __init__(self, config: GEBConfig, layer_number):
        super(CoreAttention, self).__init__()
        # self.fp16 = config.fp16
        # self.bf16 = config.bf16
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # With layer scaling, scores carry an extra 1/layer_number factor that forward()
        # multiplies back in after casting to fp32, so the fp32 softmax path is forced.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.num_layers = config.num_layers
        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values (the per-partition sizes here
        # equal the full sizes).
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads
        coeff = None
        # Scores are divided by norm_factor = sqrt(head_dim) [* layer_number].
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff
        # Dropout applied to the attention probabilities.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute softmax(q k^T / norm_factor) v.

        Args:
            query_layer: [sq, b, np, hn]
            key_layer: [sk, b, np, hn]
            value_layer: [sk, b, np, hn]
            attention_mask: boolean mask broadcastable to [b, np, sq, sk]; True means
                "do not attend". If None and sq == sk, a causal mask is built here;
                if None and sq != sk, no mask is applied.

        Returns:
            Context tensor of shape [sq, b, np * hn].
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================
        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))
        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)
        # Preallocate the [b * np, sq, sk] output for baddbmm. Because beta=0.0 below,
        # its uninitialised contents are ignored.
        matmul_input_buffer = torch.empty(
            output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
            device=query_layer.device
        )
        # Raw attention scores, already scaled by 1/norm_factor. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)
        # ===========================
        # Attention probs and dropout
        # ===========================
        # attention scores and attention mask [b, np, sq, sk]
        if self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()
        # Undo the layer_number factor folded into norm_factor (now in fp32).
        if self.coeff is not None:
            attention_scores = attention_scores * self.coeff
        # No mask given for a square score matrix: build a causal mask
        # (lower triangle visible, then inverted so True = masked).
        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
            attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                        device=attention_scores.device, dtype=torch.bool)
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.type_as(value_layer)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # context layer shape: [b, np, sq, hn]
        # (query_layer was reshaped above; its size(0) is still sq.)
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))
        # change view [sk, b * np, hn]
        value_layer = value_layer.contiguous().view(value_layer.size(0),
                                                    output_size[0] * output_size[1], -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
class FlashSelfAttention(torch.nn.Module):
    """Scaled dot-product attention backed by FlashAttention.

    On CUDA, the FlashAttention-2 varlen kernel (``flash_attn_varlen_func``) is used
    with batch and sequence packed into one axis. On other accelerators, the DeepSpeed
    ``FlashAttentionBuilder`` op (``flash_attn_builder``) is used with
    ``(b, h, s, d)`` inputs.

    Arguments
    ---------
        config: model config; ``config.use_flash_attn`` enables the CUDA kernel.
        causal: apply a causal mask in training mode. In eval mode the causal mask
            is applied only while q and k have the same length.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Applied only in training mode.
    """

    def __init__(self, config: GEBConfig, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_varlen_func is not None or flash_attn_builder is not None, \
            ('Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.config = config
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        # Use FlashAttention-2 (varlen kernel) when config.use_flash_attn is True.
        if config.use_flash_attn:
            self.flash_attn_func = flash_attn_varlen_func
        else:
            logger.warning("config.use_flash_attn is False; the CUDA FlashAttention-2 kernel is disabled")
            self.flash_attn_func = None

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)

        Returns:
            The attention output, (B, S, H, D).
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        # Query the accelerator once; both the kernel and the tensor layout depend on it.
        on_cuda = get_accelerator().device_name() == 'cuda'

        if on_cuda:
            # Varlen kernel: pack batch and sequence into one axis and describe the
            # start offset of each sequence with cumulative lengths.
            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                        device=q.device)
        else:
            # The DeepSpeed op takes (b, h, s, d).
            q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]

        if self.training:
            # during training q, k, v always have the same seqlen
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q if on_cuda else None
            dropout_p = self.dropout_p
        else:
            # Turn off the causal mask after the first autoregressive step:
            # only on the first step do q and k have the same seqlen.
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                        device=q.device) if on_cuda else None
            # No attention dropout at inference time.
            dropout_p = 0

        if on_cuda:
            output = self.flash_attn_func(
                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale, causal=is_causal
            )
            return rearrange(output, '(b s) ... -> b s ...', b=batch_size)

        # Use the mode-dependent dropout_p (0 in eval), not self.dropout_p, so that
        # inference on non-CUDA accelerators is not randomly perturbed.
        output = flash_attn_builder.flash_attn_func(
            q, k, v, dropout_p, self.softmax_scale, is_causal
        )
        return rearrange(output, 'b h s d -> b s h d').contiguous()
class GEBAttention(nn.Module):
    """Grouped-query self-attention layer.

    Self-attention layer takes input with size [s, b, h] and returns output of
    the same size, together with the (optionally updated) key/value cache.

    Queries use ``num_attention_heads`` heads. Keys and values are projected to
    ``num_key_value_heads`` heads and repeated ``num_key_value_groups`` times to
    match the query heads. Attention is computed by ``FlashSelfAttention`` when
    ``config.use_flash_attn`` is set, otherwise by ``CoreAttention``.
    """

    def __init__(self, config: GEBConfig, layer_number, device=None):
        super().__init__()
        self.config = config
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn
        # Per attention head and per partition values (per-partition sizes equal
        # the full sizes here).
        self.hidden_size_per_partition = self.projection_size
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.num_key_value_heads_per_partition = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.kv_projection_size = config.kv_channels * config.num_key_value_heads
        assert self.hidden_size_per_attention_head == self.kv_projection_size // config.num_key_value_heads
        # self.max_position_embeddings = config.model_max_length
        if self.use_flash_attn:
            # Load the DeepSpeed flash-attention op into the module-level slot that
            # FlashSelfAttention uses on non-CUDA accelerators.
            global flash_attn_builder
            try:
                flash_attn_builder = FlashAttentionBuilder().load()
            except TypeError:
                flash_attn_builder = None
            assert flash_attn_varlen_func is not None, "Cannot import FlashAttention v2 "
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')
        self.query = nn.Linear(config.hidden_size, self.projection_size,
                               bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )
        # Keys and values share one projection: 2 * num_key_value_heads * head_dim.
        self.key_value = nn.Linear(config.hidden_size, 2 * self.kv_projection_size,
                                   bias=config.add_bias_linear,
                                   device=device, **_config_to_kwargs(config)
                                   )
        if config.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(config, causal=True, attention_dropout=config.attention_dropout)
        else:
            self.core_attention = CoreAttention(config, self.layer_number)
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        # NOTE(review): the third axis is sized by num_key_value_groups (a head ratio),
        # not by a head count; the kv cache built in forward() has
        # num_attention_heads heads. Not called in the visible code -- confirm before use.
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_key_value_groups,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device)

    def repeat_kv(self, hidden_states, n_rep):
        """Repeat each kv head ``n_rep`` times: [s, b, nkv, d] -> [s, b, nkv * n_rep, d]."""
        slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, :, None, :].expand(
            slen, batch, num_key_value_heads_per_partition, n_rep, head_dim)
        return hidden_states.reshape(slen, batch,
                                     num_key_value_heads_per_partition * n_rep,
                                     head_dim)

    def forward(self, hidden_states, attention_mask,
                rotary_pos_emb=None, kv_cache=None, use_cache=True):
        """Run self-attention over ``hidden_states`` ([sq, b, h]).

        Args:
            hidden_states: input activations, [sq, b, h].
            attention_mask: boolean mask passed to CoreAttention (True = masked).
                Not used on the FlashAttention path.
            rotary_pos_emb: rotary angle table, or a ``(q_emb, k_emb)`` tuple.
            kv_cache: optional ``(key, value)`` from previous steps, concatenated in
                front of the new keys/values along dim 0.
            use_cache: if True, return the concatenated ``(key, value)`` as the cache.

        Returns:
            ``(output, kv_cache)``: output is [sq, b, h]; kv_cache is the updated
            cache or None.
        """
        # Attention head [sq, b, h] --> [sq, b, hp]
        query_layer = self.query(hidden_states)
        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_tensor_shape)
        # [sq, b, h] --> [sq, b, (nkv * 2 * hn)]
        mixed_kv_layer = self.key_value(hidden_states)
        # [sq, b, (nkv * 2 * hn)] --> [sq, b, nkv, 2 * hn]
        new_tensor_shape = mixed_kv_layer.size()[:-1] + \
            (self.num_key_value_heads_per_partition,
             2 * self.hidden_size_per_attention_head)
        mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
        # [sq, b, nkv, 2 * hn] --> 2 x [sq, b, nkv, hn]
        (key_layer,
         value_layer) = split_tensor_along_last_dim(
            mixed_kv_layer, 2)
        # Repeat kv heads to match query heads: [sq, b, nkv, hn] --> [sq, b, np, hn]
        key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
        value_layer = self.repeat_kv(value_layer, self.num_key_value_groups)
        # Rotary embeddings; a single table is used for both q and k.
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb,) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
        # Prepend cached keys/values along the sequence axis.
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        # The cache holds keys/values after repeat_kv and rotary (np heads).
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None
        if self.use_flash_attn:
            query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                   for x in (query_layer, key_layer, value_layer)]
            context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
        else:
            context_layer = self.core_attention(
                query_layer, key_layer, value_layer, attention_mask)
        # Output projection: [sq, b, hp] --> [sq, b, h]
        output = self.dense(context_layer)
        return output, kv_cache
class GEBBlock(torch.nn.Module):
    """A single transformer layer operating on [s, b, h] tensors.

    Layout: RMSNorm -> self-attention -> residual add -> RMSNorm -> MLP -> residual
    add. With ``apply_residual_connection_post_layernorm``, each residual branch
    starts from the normalised tensor instead of the block input. ``forward``
    returns ``(output, kv_cache)``.
    """

    def __init__(self, config: GEBConfig, layer_number, device=None):
        super(GEBBlock, self).__init__()
        self.layer_number = layer_number
        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        # self.bf16 = config.bf16
        # Stored from the config; not read elsewhere in this class.
        self.fp32_residual_connection = config.fp32_residual_connection
        norm_kwargs = dict(eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype)
        self.input_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.self_attention = GEBAttention(config, layer_number, device=device)
        # NOTE(review): stored but unused -- forward() hard-codes dropout p=0.0.
        self.hidden_dropout = config.hidden_dropout
        self.post_attention_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, attention_mask=None,
                rotary_pos_emb=None,
                kv_cache=None,
                use_cache=True):
        residual_from_norm = self.apply_residual_connection_post_layernorm

        # --- attention sub-block ---
        attn_in = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            attn_in,
            attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache)
        attn_residual = attn_in if residual_from_norm else hidden_states
        mlp_in_raw = attn_residual + torch.nn.functional.dropout(
            attention_output, p=0.0, training=self.training)

        # --- MLP sub-block ---
        mlp_in = self.post_attention_layernorm(mlp_in_raw)
        mlp_output = self.mlp(mlp_in)
        mlp_residual = mlp_in if residual_from_norm else mlp_in_raw
        output = mlp_residual + torch.nn.functional.dropout(
            mlp_output, p=0.0, training=self.training)
        return output, kv_cache
class GEBTransformer(torch.nn.Module):
    """Stack of ``config.num_layers`` GEBBlock layers with an optional final RMSNorm."""

    def __init__(self, config: GEBConfig, device=None):
        super(GEBTransformer, self).__init__()
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers
        # Layer numbers are 1-based (CoreAttention uses them for query-key scaling).
        self.layers = torch.nn.ModuleList(
            [GEBBlock(config, layer_number, device=device)
             for layer_number in range(1, self.num_layers + 1)]
        )
        if self.post_layer_norm:
            self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                           dtype=config.torch_dtype)
        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        """Run all layers.

        Args:
            hidden_states: input to the first layer ([s, b, h]).
            attention_mask: forwarded unchanged to every layer.
            rotary_pos_emb: forwarded unchanged to every layer.
            kv_caches: optional per-layer caches; a falsy value means "no cache".
            use_cache: collect each layer's cache into ``presents``. Forced to False
                during gradient-checkpointed training.
            output_hidden_states: also return the input of each layer plus the last
                layer's output (taken before the final layernorm).

        Returns:
            ``(hidden_states, presents, all_hidden_states, all_self_attentions)``.
            ``presents`` is None when caching is off, and ``all_self_attentions`` is
            always None.
        """
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False
        # Create the cache container only after use_cache has its final value, so
        # gradient-checkpointed training reports None rather than an empty tuple.
        presents = () if use_cache else None
        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_hidden = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_hidden = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_hidden
            if use_cache:
                presents = presents + (kv_cache,)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)
        return hidden_states, presents, all_hidden_states, all_self_attentions
class GEBPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = GEBConfig
    base_model_prefix = "transformer"
    # GEBBlock instances are kept whole when a device map is built.
    _no_split_modules = ["GEBBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights.

        This implementation does nothing: parameters keep whatever values their
        constructors produced, or the values loaded from a checkpoint.
        """
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        """Build a boolean attention mask of shape [b, 1, s, past_length + s].

        True marks positions that must NOT be attended to (the convention used by
        CoreAttention's masked_fill). The mask is causal over the current ``s``
        tokens, fully visible over the ``past_length`` cached positions, and hides
        key positions where ``padding_mask`` is 0.

        Args:
            input_ids: [b, s] token ids; only the shape and device are used.
            past_key_values: per-layer caches; the past length is read from
                ``past_key_values[0][0].shape[0]``.
            padding_mask: optional [b, past_length + s] mask; presumably 1 = real
                token and 0 = padding (TODO confirm against callers).
        """
        batch_size, seq_length = input_ids.shape
        # 1 = may attend; start from a lower-triangular (causal) matrix.
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        # Every current token may attend to every cached position.
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        # Zero out padded key columns.
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        # Without a cache, add 1 to every entry of padded query rows so those rows
        # are not fully masked (an all -inf row would make the softmax produce NaN).
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # Invert: True = masked.
        full_attention_mask = (full_attention_mask < 0.5).bool()
        # Add a broadcastable head axis: [b, s, k] -> [b, 1, s, k].
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        """Return positions 0..s-1 repeated for each batch row, shape [b, s]."""
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook for transformers' gradient-checkpointing toggle; only GEBTransformer
        # reads the flag.
        if isinstance(module, GEBTransformer):
            module.gradient_checkpointing = value
class Embedding(torch.nn.Module):
    """Language model word embeddings, returned in the sequence-first [s, b, h] layout."""

    def __init__(self, config: GEBConfig, device=None):
        super(Embedding, self).__init__()
        self.hidden_size = config.hidden_size
        # Word embeddings over the padded vocabulary.
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        """Embed ``input_ids`` ([b, s]) and return a contiguous [s, b, h] tensor.

        The result is upcast to float32 when ``fp32_residual_connection`` is set.
        """
        # [b, s, h] -> [s, b, h]: the transformer stack works sequence-first.
        embeddings = self.word_embeddings(input_ids).transpose(0, 1).contiguous()
        return embeddings.float() if self.fp32_residual_connection else embeddings
class GEBModel(GEBPreTrainedModel):
    """Bare GEB transformer: token embedding, rotary position table, encoder stack and output projection.

    ``output_layer`` (hidden size -> padded vocab size) is constructed here but is not applied in
    ``forward``; ``GEBForCausalLM`` applies it to the hidden states this module returns.
    """

    def __init__(self, config: GEBConfig, device=None, empty_init=True):
        """Build all sub-modules from ``config``.

        Args:
            config: model configuration.
            device: optional device forwarded (as ``device=``) to the embedding, encoder and
                output-layer constructors.
            empty_init: when True, those sub-modules are created through
                ``torch.nn.utils.skip_init`` (parameters allocated but not initialized --
                presumably because weights are loaded afterwards); otherwise they are
                constructed normally.
        """
        super().__init__(config)
        init_method = skip_init if empty_init else default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        # Number of query heads sharing each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings: per-head width unless kv_channels overrides it.
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
        self.encoder = init_method(GEBTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)

        # Optional prefix tuning. Every parameter that exists so far is frozen BEFORE the prefix
        # encoder is created, so only the prefix encoder remains trainable.
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        """Return the word-embedding module of ``self.embedding``."""
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        """Build prefix key/value caches from the trainable prefix encoder.

        Returns:
            A tuple produced by ``split(2)`` over the stacked ``num_layers * 2`` axis, i.e. one
            entry per layer, each of shape ``(2, pre_seq_len, batch_size, heads, kv_channels)``
            (key and value stacked on dim 0).
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        # The head axis is inferred (-1) instead of using num_key_value_groups: that attribute is
        # query-heads-per-KV-head, which need not equal the number of KV heads the prefix encoder
        # projects to. Whenever the two agree this is exactly the same view as before.
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            -1,
            self.kv_channels,
        )
        past_key_values = self.dropout(past_key_values)
        # (batch, pre_seq, 2*layers, heads, dim) -> (2*layers, pre_seq, batch, heads, dim),
        # then chunk into per-layer (key, value) pairs.
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def _default_position_ids(self, input_ids, past_key_values):
        """Return contiguous positions of shape ``(1, seq_length)`` continuing after cached tokens.

        Cache entries are taken to be sequence-first per key tensor, matching the layout built by
        ``get_prompt``. When prefix tuning is enabled the prefix occupies ``pre_seq_len`` cache
        slots; those are excluded from the offset.
        NOTE(review): assumes positions restart at 0 after the prefix -- confirm against
        ``get_position_ids``.
        """
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
            if self.pre_seq_len is not None:
                past_length -= self.pre_seq_len
        seq_length = input_ids.shape[1]
        positions = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=input_ids.device)
        return positions.unsqueeze(0)

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Embed the inputs and run the encoder stack.

        Args:
            input_ids: token ids of shape ``(batch, seq_length)``. Required: batch/sequence sizes,
                the prefix-prompt device and ``get_masks`` are all derived from it.
            position_ids: positions used to gather rows of the rotary table. Only the first row is
                used, for the whole batch. When omitted, contiguous positions continuing after the
                caller-supplied cache are used.
            attention_mask: 2-D ``(batch, ...)`` mask. When prefix tuning is enabled it is
                extended with ``pre_seq_len`` leading ones.
            full_attention_mask: precomputed encoder mask. When omitted, it is built with
                ``get_masks`` only if ``attention_mask`` contains zeros, or if a cache is present
                and more than one token is fed.
            past_key_values: per-layer cached (key, value) tensors. When prefix tuning is enabled
                and no cache is given, the prefix prompt from ``get_prompt`` is used.
            inputs_embeds: optional precomputed embeddings; computed from ``input_ids`` otherwise.
            use_cache, output_hidden_states, return_dict: default to the matching config values.

        Returns:
            ``BaseModelOutputWithPast``, or (when ``return_dict`` is False) a tuple of the
            non-``None`` values among hidden states, presents, all hidden states and attentions.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            # Must run before the prefix prompt is injected into past_key_values below, so the
            # offset only reflects the cache the caller passed in.
            position_ids = self._default_position_ids(input_ids, past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings: build the table for the configured maximum length, then
        # gather the rows for this step's positions (first batch row used for the whole batch).
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids[0].tolist()]

        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class GEBForCausalLM(GEBPreTrainedModel):
    """Causal-LM wrapper around ``GEBModel``; logits come from ``transformer.output_layer``."""

    def __init__(self, config: GEBConfig, empty_init=True, device=None):
        """Build the wrapped ``GEBModel`` (``empty_init`` and ``device`` are forwarded to it)."""
        super().__init__(config)
        self.max_sequence_length = config.max_length
        self.transformer = GEBModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Advance ``model_kwargs`` by one generated token.

        Stores the new cache, appends a ``1`` column to ``attention_mask``, appends
        ``last position + 1`` to ``position_ids`` and clears ``is_first_forward``. Mask and
        position updates are skipped when the entry is missing or ``None``.
        """
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        attention_mask = model_kwargs.get("attention_mask")
        if attention_mask is not None:
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        position_ids = model_kwargs.get("position_ids")
        if position_ids is not None:
            new_position_id = position_ids[..., -1:] + 1
            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
        **kwargs
    ) -> dict:
        """Assemble ``forward`` kwargs for one generation step.

        Positions default to ``self.get_position_ids(input_ids)``. After the first step, when a
        cache exists, only the last token and its position are passed. ``return_last_logit`` is
        always True so only the final position is projected to the vocabulary.
        """
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_last_logit: Optional[bool] = False,
    ):
        """Run the transformer, project to vocabulary logits and optionally compute the LM loss.

        The hidden states are treated as sequence-first: ``return_last_logit`` keeps only
        ``hidden_states[-1:]``, and the logits are transposed on dims 0/1 before the loss.

        Args:
            labels: target ids. When given, a next-token cross-entropy loss (``ignore_index=-100``)
                is computed in float32 on logits shifted by one; logits and loss are then cast
                back to the hidden-state dtype.
            output_attentions: accepted for API compatibility; not forwarded to the transformer.
            return_last_logit: project only the last position (used during generation).

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``([loss,] logits, *transformer_outputs[1:])``
            when ``return_dict`` is False.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            # Loss in float32 for numerical stability.
            lm_logits = lm_logits.to(torch.float32)
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Beams are selected along dim 1 of every cached key and value tensor.
        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
        """Encode ``tokenizer.build_prompt(query, history)`` prefixed by ``<bos>``, on ``self.device``."""
        prompt = tokenizer.build_prompt(query, history=history)
        tokens = [tokenizer.get_command("<bos>")] + tokenizer.encode(prompt)
        inputs = tokenizer.batch_encode_plus([tokens], return_tensors="pt", is_split_into_words=True)
        inputs = inputs.to(self.device)
        return inputs

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 512, num_beams=1,
             do_sample=True, top_p=0.5, temperature=0.3, logits_processor=None, repetition_penalty=1.15, **kwargs):
        """Generate one reply to ``query``.

        The input is ``<bos> system <eos>`` followed by ``<bos> prompt <eos> <bos>``, where
        ``prompt = tokenizer.build_prompt(query, history=[])`` and the system text is
        "You are a helpful assistant.\\n". Sampling defaults to ``top_p=0.5``,
        ``temperature=0.3``, ``repetition_penalty=1.15``; extra ``kwargs`` go to ``generate``.
        An ``InvalidScoreLogitsProcessor`` is always added on top of ``logits_processor``; the
        caller's list itself is not modified.

        NOTE: ``history`` is accepted for interface compatibility but is not fed into the prompt,
        and it is returned unchanged (the new turn is not appended).

        Returns:
            ``(response, history)`` where ``response`` is the decoded text generated after the
            prompt.
        """
        if history is None:
            history = []
        # Work on a copy: appending to the caller's list would accumulate another
        # InvalidScoreLogitsProcessor on every call that reuses it.
        logits_processor = LogitsProcessorList(logits_processor if logits_processor is not None else [])
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor,
                      "repetition_penalty": repetition_penalty, **kwargs}

        prompt = tokenizer.build_prompt(query, history=[])
        system = "You are a helpful assistant.\n"
        bos = tokenizer.get_command("<bos>")
        eos = tokenizer.get_command("<eos>")
        # NOTE(review): unlike the prompt, the system text is encoded without
        # add_special_tokens=False; confirm the tokenizer adds no extra special tokens here.
        system_ids = [bos] + tokenizer.encode(text=system) + [eos]
        prompt_ids = [bos] + tokenizer.encode(text=prompt, add_special_tokens=False) + [eos, bos]
        tokens = system_ids + prompt_ids

        inputs = tokenizer.batch_encode_plus([tokens], return_tensors="pt", is_split_into_words=True)
        # Follow the model's actual device instead of guessing "cuda"/"cpu", which broke
        # CPU-resident models on CUDA hosts and models placed on a non-default GPU.
        inputs = inputs.to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs)
        prompt_length = len(inputs["input_ids"][0])
        response = tokenizer.decode(outputs[0][prompt_length:].tolist())
        return response, history