File size: 3,185 Bytes
b0b5da4 026cf13 b0b5da4 026cf13 b0b5da4 026cf13 b0b5da4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import copy
class ChatTemplate:
    """ChatML-style (``<|im_start|>role\\n...<|im_end|>\\n``) prompt tokenizer.

    Wraps a llama.cpp-style ``model`` (must provide ``tokenize(bytes,
    add_bos, special)`` and ``str_detokenize(tokens)``) and pre-computes the
    token ids of the template delimiters, so that end-of-turn markers (a role
    header or a blank line) can be detected in generated token streams.
    """

    def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
        """Pre-tokenize the template delimiters for ``model``.

        Args:
            model: llama-cpp-python ``Llama``-like object.
            im_start: turn-opening marker text.
            im_end: turn-closing marker text.
            nl: newline text appended after ``im_end`` and after the role name.
        """
        self.model = model
        self.nl = nl
        self.im_start = im_start
        self.im_start_token = model.tokenize(self.im_start.encode('utf-8'), add_bos=False, special=True)
        self.im_end = im_end
        self.im_end_nl = model.tokenize((self.im_end + self.nl).encode('utf-8'), add_bos=False, special=True)
        # NOTE(review): reaches into the model's private ``_token_eos`` —
        # confirm no public accessor exists in the wrapped model class.
        self.eos = [model._token_eos, self.im_end_nl[0]]
        # Token id(s) representing a single newline; some vocabularies also
        # collapse '\r\n' into one token, so include that id when they do.
        self.onenl = [self.im_end_nl[-1]]
        tmp = model.tokenize(('\r' + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.onenl.append(tmp[0])
        self.onerl = model.tokenize(b'\r', add_bos=False, special=True)
        # Single-token id for '\n\n' if the vocabulary has one, else None.
        self.nlnl = None
        tmp = model.tokenize((self.nl + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.nlnl = tmp[0]
        # BUGFIX: ``cache`` and ``roles`` used to be *class* attributes, i.e.
        # shared by every instance. Tokenization is model-specific, so two
        # templates built over different models would poison each other's
        # cache and role set; they must be per-instance state.
        self.cache = {}    # role -> tokens of "<|im_start|>role\n"
        self.roles = set()  # '\n' + role strings, for end-of-turn detection
        print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)

    def _get(self, key: str):
        """Return the tokenized ``<|im_start|>key\\n`` prefix, memoized."""
        if key in self.cache:
            return copy.deepcopy(self.cache[key])  # deep-copy so callers cannot mutate the cached list
        value = self.model.tokenize((self.im_start + key + self.nl).encode('utf-8'), add_bos=False, special=True)
        self.cache[key] = copy.deepcopy(value)  # deep-copy for the same reason
        return value

    def _add_role(self, _role):
        """Remember ``'\\n' + role`` so generated text can be scanned for it."""
        if _role:
            self.roles.add('\n' + _role)

    def eos_in_role(self, history: str, t_bot):
        """Detect a trailing role header in generated output.

        If ``history`` ends with newline(s) preceded by a known role header
        (e.g. ``"\\nuser"``), return how many trailing tokens of ``t_bot``
        make up that header so the caller can discard them; otherwise 0.
        """
        if not (history.endswith('\n') or history.endswith('\r')):
            return 0
        stripped = history.rstrip()
        for _role in self.roles:
            if stripped.endswith(_role):
                n = len(t_bot)
                for i in range(1, n):  # find how many trailing tokens to discard
                    tail = self.model.str_detokenize(t_bot[n - i:])
                    if tail.rstrip().endswith(_role):
                        print('eos_in_role', t_bot[n - i:], repr(tail))
                        return i
                print('eos_in_role missing')
                break
        return 0

    def eos_in_nlnl(self, history: str, t_bot):
        """Detect a trailing blank line in generated output.

        If ``history`` ends with a blank line (``'\\n\\n'`` or ``'\\n\\r\\n'``),
        return how many trailing tokens of ``t_bot`` produced it so the caller
        can discard them; otherwise 0.
        """
        if not (history.endswith('\n\n') or history.endswith('\n\r\n')):
            return 0
        n = len(t_bot)
        for i in range(1, n):  # find how many trailing tokens to discard
            tail = self.model.str_detokenize(t_bot[n - i:])
            if tail.endswith('\n\n') or tail.endswith('\n\r\n'):
                if tail.startswith(']'):  # avoid false positives (e.g. blank line inside a bracketed span)
                    return 0
                print('eos_in_nlnl', t_bot[n - i:], repr(tail))
                return i
        print('eos_in_nlnl missing')
        return 0

    def __call__(self, _role, prompt=None):
        """Tokenize one chat turn.

        With ``prompt`` None, return just the ``<|im_start|>role\\n`` prefix
        (an open turn for the model to complete); otherwise return the full
        closed turn ``<|im_start|>role\\nprompt<|im_end|>\\n``.
        """
        self._add_role(_role)
        if prompt is None:
            return self._get(_role)
        text = self.im_start + _role + self.nl + prompt
        return self.model.tokenize(text.encode('utf-8'), add_bos=False, special=True) + self.im_end_nl
|