from typing import Optional

import torch
from transformers import AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel


class FasterChatGLM(PreTrainedModel):
    """ChatGLM wrapper whose forward pass is delegated to an external inference kernel."""

    def __init__(self, model_dir, kernel, *inputs, **kwargs):
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        # Alias the ChatGLM config fields under GPT-2-style names (n_head / n_embd / n_layer).
        config.n_head = config.num_attention_heads
        config.n_embd = config.hidden_size
        config.n_layer = config.num_layers
        super().__init__(config, *inputs, **kwargs)
        self.kernel = kernel
        # Dummy layer so the wrapper registers at least one torch parameter.
        self.fake_reg = torch.nn.Linear(2, 2)
        self.position_encoding_2d = True

    def forward(self, input_ids, position_ids, attention_mask, past_key_values, *args, **kwargs):
        # Pack the tensors in the order the kernel expects; past key/values are
        # appended only during incremental decoding.
        inputs_values = [input_ids, position_ids, attention_mask]
        if past_key_values is not None:
            inputs_values = inputs_values + past_key_values

        computed = self.kernel.infer(inputs_values)
        logits = computed[0]
        # Anything beyond the logits is the updated key/value cache.
        present_key_values = None if len(computed) == 1 else computed[1:]

        return CausalLMOutputWithPast(logits=logits, past_key_values=present_key_values)

    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
        # Lower-triangular causal mask, then allow full attention over the prompt
        # prefix (every position except the last). After the comparison below,
        # True marks positions that must be masked out.
        attention_mask = torch.ones((1, context_length, context_length), device=device)
        attention_mask.tril_()
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
            # 2D positional encoding: token positions within the prompt plus block
            # positions counting the tokens after the BOS/separator (id 150004).
            seq_length = seq.index(150004)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)
        return attention_mask, position_ids

    def prepare_one_sample(self, input_id, mask_token, past, past_key_values, use_gmask):
        seq = input_id.tolist()
        # Validate before indexing, otherwise seq.index() would raise first.
        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        if past is not None or past_key_values is not None:
            # Incremental decoding: feed only the last token, a single-step
            # position id and a dummy attention mask.
            context_length = seq.index(150004)
            last_token = input_id[-1].unsqueeze(-1).unsqueeze(0)  # shape (1, 1)
            proc_input_id = last_token
            if self.position_encoding_2d:
                position_ids = torch.tensor(
                    [[[mask_position], [len(seq) - context_length]]],
                    dtype=torch.long, device=input_id.device)
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_id.device)
            attention_mask = torch.zeros(1, 1, 1, 1, device=input_id.device)
        else:
            # Context (prefill) phase: feed the whole prompt with full masks and positions.
            proc_input_id = input_id.unsqueeze(0)
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_id.device,
                gmask=use_gmask
            )

        return (proc_input_id.to(torch.int32), position_ids.to(torch.int32),
                attention_mask.to(torch.bool))

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            **kwargs
    ) -> dict:
        # ChatGLM prompts contain a [MASK] or [gMASK] placeholder; pick whichever is present.
        MASK, gMASK = 150000, 150001
        mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = mask_token == gMASK

        # Build per-sample inputs, then stack them back into a batch.
        batch_input_ids, batch_position_ids, batch_attention_mask = [], [], []
        for input_id in input_ids:
            proc_input_id, position_id, attn_mask = self.prepare_one_sample(
                input_id, mask_token, past, past_key_values, use_gmask)
            batch_input_ids.append(proc_input_id)
            batch_position_ids.append(position_id)
            batch_attention_mask.append(attn_mask)

        batch_input_ids = torch.vstack(batch_input_ids)
        batch_position_ids = torch.vstack(batch_position_ids)
        batch_attention_mask = torch.vstack(batch_attention_mask)

        if past is None:
            past = past_key_values

        # Context (prefill) mode is only used for the first step; once a cache
        # exists, the kernel switches to incremental decoding.
        if past is not None or past_key_values is not None:
            self.kernel.set_context_mode(False)
        else:
            self.kernel.set_context_mode(self.config.use_cache)

        return {
            "input_ids": batch_input_ids,
            "past_key_values": past_key_values,
            "position_ids": batch_position_ids,
            "attention_mask": batch_attention_mask
        }
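

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a hypothetical build_kernel() factory returning an object with the
# infer(...) and set_context_mode(...) methods used above, and a local or hub
# copy of the ChatGLM-6B weights/tokenizer under `model_dir`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_dir = "THUDM/chatglm-6b"        # assumed checkpoint location
    kernel = build_kernel(model_dir)      # hypothetical: depends on the accelerated runtime in use

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = FasterChatGLM(model_dir, kernel=kernel).half().cuda()

    inputs = tokenizer(["Who is the CEO of Tesla?"], return_tensors="pt").to("cuda")
    outputs = model.generate(inputs.input_ids, max_new_tokens=64)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))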