# Scraped file-view header (kept as a comment so the module parses):
# author: bigmoyan, commit acff406 ("Upload 12 files"), file size 5.52 kB.
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig
from typing import Dict, List, Tuple, Union, Optional
class FasterChatGLM(PreTrainedModel):
    """ChatGLM ``PreTrainedModel`` shell whose forward pass runs in an external kernel.

    The base class supplies the generation loop; this class builds the
    per-step inputs (token ids, position ids, attention mask) and hands them
    to ``kernel``. From its use here, ``kernel`` must provide:

    * ``infer(list_of_tensors)`` returning a sequence whose first element is
      the logits and whose remaining elements (if any) are the present
      key/value cache;
    * ``set_context_mode(bool)``.
    """

    # Special token ids hard-coded for this model's tokenizer
    # (TODO confirm against the tokenizer shipped with the checkpoint).
    _MASK_TOKEN_ID = 150000   # [MASK]
    _GMASK_TOKEN_ID = 150001  # [gMASK]
    # The first occurrence of this id marks the end of the prompt context;
    # presumably ChatGLM's BOS/<sop> token — confirm.
    _BOS_TOKEN_ID = 150004

    def __init__(self, model_dir, kernel, *inputs, **kwargs):
        """Load the config from ``model_dir`` and attach the inference ``kernel``.

        Args:
            model_dir: Path or hub id passed to ``AutoConfig.from_pretrained``
                (loaded with ``trust_remote_code=True``).
            kernel: Inference backend exposing ``infer`` and ``set_context_mode``.
            *inputs, **kwargs: Forwarded to ``PreTrainedModel.__init__``.
        """
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        # Mirror ChatGLM config fields under GPT-2-style names; presumably some
        # downstream consumer reads these names — confirm before removing.
        config.n_head = config.num_attention_heads
        config.n_embd = config.hidden_size
        config.n_layer = config.num_layers
        super().__init__(config, *inputs, **kwargs)
        self.kernel = kernel
        # NOTE(review): never used in computation; presumably exists so the
        # module owns at least one parameter (e.g. for device/dtype lookups).
        self.fake_reg = torch.nn.Linear(2, 2)
        # Always use the 2-D (position, block-position) encoding.
        self.position_encoding_2d = True

    def forward(self, input_ids, position_ids, attention_mask, past_key_values, *args, **kwargs):
        """Run one inference step through the kernel.

        Args:
            input_ids, position_ids, attention_mask: Tensors built by
                ``prepare_inputs_for_generation``.
            past_key_values: Cache tensors from the previous step, or ``None``.
                Any iterable (list or tuple) is accepted.
            *args, **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` is the kernel's first
            output and whose ``past_key_values`` holds the remaining outputs, or
            ``None`` when the kernel returned only logits.
        """
        inputs_values = [input_ids, position_ids, attention_mask]
        if past_key_values is not None:
            # extend() accepts tuples as well as lists; ``list + tuple`` would
            # raise TypeError.
            inputs_values.extend(past_key_values)
        computed = self.kernel.infer(inputs_values)
        logits = computed[0]
        present_key_values = computed[1:] if len(computed) > 1 else None
        return CausalLMOutputWithPast(logits=logits, past_key_values=present_key_values)

    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
        """Build the full-sequence attention mask and position ids for one sample.

        Args:
            seq: Token ids of the sample as a Python list; must contain
                ``_BOS_TOKEN_ID`` when the 2-D encoding is used.
            mask_position: Index of the [MASK]/[gMASK] token in ``seq``.
            context_length: Sequence length (callers here pass ``len(seq)``).
            device: Device on which the tensors are created.
            gmask: When false, positions from the context end onward are
                pinned to ``mask_position``; when true they keep their own index.

        Returns:
            ``(attention_mask, position_ids)``. ``attention_mask`` is a bool
            tensor of shape ``(1, 1, L, L)`` in which ``True`` marks a blocked
            pair. ``position_ids`` has shape ``(1, 2, L)`` with the 2-D
            encoding, otherwise ``(1, L)``.
        """
        attention_mask = torch.ones((1, context_length, context_length), device=device)
        attention_mask.tril_()
        # All columns except the last become visible to every row; only the
        # last column stays causal.
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        # Invert so that True means "masked out".
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
            seq_length = seq.index(self._BOS_TOKEN_ID)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            # Second channel: 0 over the context, then 1, 2, ... afterwards.
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1,
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position
        position_ids = position_ids.unsqueeze(0)
        return attention_mask, position_ids

    def prepare_one_sample(self, input_id, mask_token, past, past_key_values, use_gmask):
        """Build kernel inputs for one sample of the batch.

        Args:
            input_id: Token ids of one sample (assumed 1-D — TODO confirm).
            mask_token: Id of the mask token that must appear in ``input_id``.
            past, past_key_values: Cache from the previous step; if either is
                not ``None``, only the newest token is prepared.
            use_gmask: Forwarded as ``gmask`` to ``get_masks_and_position_ids``.

        Returns:
            ``(input_ids, position_ids, attention_mask)``. The first two are
            cast to ``int32`` and the mask to ``bool``. Each carries a leading
            batch dimension of 1.

        Raises:
            ValueError: If ``mask_token`` does not occur in ``input_id``.
        """
        seq = input_id.tolist()
        # Check before calling index(); otherwise list.index raises its own
        # ValueError and this message is never reached.
        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        if past is not None or past_key_values is not None:
            # Incremental step: only the last token is fed to the kernel.
            context_length = seq.index(self._BOS_TOKEN_ID)
            proc_input_id = input_id[-1].unsqueeze(-1).unsqueeze(0)  # (1, 1)
            if self.position_encoding_2d:
                position_ids = torch.tensor(
                    [[[mask_position], [len(seq) - context_length]]],
                    dtype=torch.long,
                    device=input_id.device,
                )
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_id.device)
            # After the inversion convention above, all-zero means "nothing blocked".
            attention_mask = torch.zeros(1, 1, 1, 1, device=input_id.device)
        else:
            proc_input_id = input_id.unsqueeze(0)
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_id.device,
                gmask=use_gmask,
            )
        return (proc_input_id.to(torch.int32), position_ids.to(torch.int32),
                attention_mask.to(torch.bool))

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> dict:
        """Build the batched ``forward`` kwargs for the current generation step.

        Args:
            input_ids: Batch of token ids; each row is prepared separately.
            past, past_key_values: Cache from the previous step. Either name is
                accepted; ``past`` takes precedence when both are given.
            attention_mask: Ignored; masks are rebuilt from the token ids.
            use_cache: Context mode to use when there is no cache yet.
                Defaults to ``self.config.use_cache`` when ``None``.
            **kwargs: Ignored.

        Returns:
            Dict with ``input_ids``, ``past_key_values``, ``position_ids`` and
            ``attention_mask`` for ``forward``.
        """
        # Use [MASK] if any sample contains it, otherwise [gMASK].
        if self._MASK_TOKEN_ID in input_ids:
            mask_token = self._MASK_TOKEN_ID
        else:
            mask_token = self._GMASK_TOKEN_ID
        use_gmask = mask_token == self._GMASK_TOKEN_ID

        batch_input_ids, batch_position_ids, batch_attention_mask = [], [], []
        for input_id in input_ids:
            proc_input_id, position_id, sample_mask = self.prepare_one_sample(
                input_id, mask_token, past, past_key_values, use_gmask)
            batch_input_ids.append(proc_input_id)
            batch_position_ids.append(position_id)
            batch_attention_mask.append(sample_mask)
        batch_input_ids = torch.vstack(batch_input_ids)
        batch_position_ids = torch.vstack(batch_position_ids)
        batch_attention_mask = torch.vstack(batch_attention_mask)

        # Merge the two cache names so the cache reaches forward() either way.
        if past is None:
            past = past_key_values
        if use_cache is None:
            use_cache = self.config.use_cache
        if past is not None:
            self.kernel.set_context_mode(False)
        else:
            self.kernel.set_context_mode(use_cache)

        return {
            "input_ids": batch_input_ids,
            # Return the merged cache. Returning only ``past_key_values`` would
            # drop a cache passed as ``past``, even though only the last token
            # is fed in that case.
            "past_key_values": past,
            "position_ids": batch_position_ids,
            "attention_mask": batch_attention_mask,
        }