import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig
from typing import Optional


class FasterChatGLM(PreTrainedModel):
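    """ChatGLM wrapper that runs the transformer forward pass on a compiled
    inference kernel (``self.kernel``) while reusing the Hugging Face
    ``generate()`` machinery inherited from ``PreTrainedModel``."""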

    def __init__(self, model_dir, kernel, *inputs, **kwargs):
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        # Expose ChatGLM config values under the GPT-style names (n_head / n_embd / n_layer).
        config.n_head = config.num_attention_heads
        config.n_embd = config.hidden_size
        config.n_layer = config.num_layers
        super().__init__(config, *inputs, **kwargs)
        self.kernel = kernel
        # Placeholder layer, not used for computation; it gives the wrapper a torch parameter.
        self.fake_reg = torch.nn.Linear(2, 2)
        self.position_encoding_2d = True

    def forward(self, input_ids, position_ids, attention_mask, past_key_values, *args, **kwargs):
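        """Run one inference step on the kernel and repackage its outputs
        (logits plus optional present key/values) as a ``CausalLMOutputWithPast``."""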
        inputs_values = [input_ids, position_ids, attention_mask]
        if past_key_values is not None:
            # Append cached key/value tensors as extra kernel inputs (accepts list or tuple).
            inputs_values = inputs_values + list(past_key_values)

        computed = self.kernel.infer(inputs_values)
        logits = computed[0]
        if len(computed) == 1:
            present_key_values = None
        else:
            present_key_values = computed[1:]

        return CausalLMOutputWithPast(logits=logits, past_key_values=present_key_values)

    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
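        """Build the causal attention mask and (optionally 2D) position ids for a full prompt."""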
        attention_mask = torch.ones((1, context_length, context_length), device=device)
        attention_mask.tril_()
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        # After the comparison, True marks positions that should be masked out.
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
            # 150004 is ChatGLM-6B's BOS (<sop>) token id; it marks where generation starts.
            seq_length = seq.index(150004)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)

        return attention_mask, position_ids

    def prepare_one_sample(self, input_id, mask_token, past, past_key_values, use_gmask):
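        """Build kernel inputs for a single sequence, distinguishing the prefill
        call (full prompt) from incremental decoding (last token only)."""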
        seq = input_id.tolist()
        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        if past is not None or past_key_values is not None:
            # Incremental decoding: only the last generated token is fed to the kernel.
            context_length = seq.index(150004)
            last_token = input_id[-1].unsqueeze(-1).unsqueeze(0)
            proc_input_id = last_token
            if self.position_encoding_2d:
                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
                                            device=input_id.device)
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_id.device)

            attention_mask = torch.zeros(1, 1, 1, 1, device=input_id.device)
        else:
            # Context (prefill) phase: feed the whole prompt with its full mask and position ids.
            proc_input_id = input_id.unsqueeze(0)
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_id.device,
                gmask=use_gmask
            )

        return (proc_input_id.to(torch.int32), position_ids.to(torch.int32),
                attention_mask.to(torch.bool))

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            use_cache: bool = None,
            **kwargs
    ) -> dict:
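        """Batch the per-sequence inputs and switch the kernel between context
        (prefill) and incremental decoding mode for the next generation step."""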
        MASK, gMASK = 150000, 150001
        mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = MASK not in input_ids

        batch_input_ids, batch_position_ids, batch_attention_mask = [], [], []
        for input_id in input_ids:
            proc_input_id, position_id, attention_mask = self.prepare_one_sample(
                input_id, mask_token, past, past_key_values, use_gmask)
            batch_input_ids.append(proc_input_id)
            batch_position_ids.append(position_id)
            batch_attention_mask.append(attention_mask)

        batch_input_ids = torch.vstack(batch_input_ids)
        batch_position_ids = torch.vstack(batch_position_ids)
        batch_attention_mask = torch.vstack(batch_attention_mask)

        if past is None:
            past = past_key_values

        if past is not None:
            # Decoding steps: run the kernel in incremental (non-context) mode.
            self.kernel.set_context_mode(False)
        else:
            # First call (prefill): context mode follows the configured cache behaviour.
            self.kernel.set_context_mode(self.config.use_cache)

        return {
            "input_ids": batch_input_ids,
            "past_key_values": past_key_values,
            "position_ids": batch_position_ids,
            "attention_mask": batch_attention_mask
        }
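
# Usage sketch (illustrative only): the kernel loader and model path below are
# assumptions; any object exposing ``infer(inputs)`` and ``set_context_mode(flag)``
# as used above should work.
#
#   kernel = load_compiled_chatglm_kernel(...)   # hypothetical loader for the compiled kernel
#   model = FasterChatGLM("THUDM/chatglm-6b", kernel).half().cuda()
#   output_ids = model.generate(input_ids, max_new_tokens=64)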