|
import numpy as np |
|
|
|
|
|
def init(cfg):
    """Prepare the status-bar entries and register the status-bar button.

    Reads from ``cfg``: ``chat_template``, ``model``, ``s_info``,
    ``session_lock``, ``text_format``, ``btn_status_bar_list``,
    ``role_char``, ``role_usr``, ``setting``, ``status_bar``,
    ``btn_concurrency``, ``btn_start``, ``btn_finish`` and
    ``btn_status_bar`` (and, at click time, ``session_active`` and
    ``btn_stop_status``).

    Each entry of ``cfg['btn_status_bar_list']`` is updated in place:
    ``key``/``desc`` are formatted with the current role names, token lists
    are stored under ``key_t``/``desc_t``/``prefix_t``, and every field in
    ``combine`` whose ``type`` is ``'int'``, ``'set'`` or ``'str'`` gets a
    generator callable under ``fn``.  The click handler is stored in
    ``cfg['btn_status_bar_fn']`` and chained between ``btn_start`` and
    ``btn_finish`` on ``cfg['btn_status_bar'].click``.

    Raises:
        ValueError: an ``int`` field is configured but ``"0123456789"`` does
            not tokenize to exactly ten tokens.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    def str_tokenize(s):
        """Tokenize ``s`` without BOS or special tokens.

        A newline is prepended before tokenizing and the first token is
        dropped again if it is one of ``chat_template.onenl`` -- presumably
        so ``s`` is tokenized as it would be at the start of a line rather
        than at the start of the whole text.
        """
        tokens = model.tokenize((chat_template.nl + s).encode('utf-8'),
                                add_bos=False, special=False)
        # Guard against an empty result instead of failing on tokens[0].
        if tokens and tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    def btn_status_bar_fn_mask():
        """Return a new float32 bias holding -inf for every column of ``model.scores``."""
        return np.full((model.scores.shape[-1],), -np.inf, dtype=np.single)

    def masked_logits_processor(mask):
        """Return a logits processor that adds the bias ``mask`` to the logits."""
        def logits_processor(_input_ids, logits):
            return logits + mask
        return logits_processor

    def btn_status_bar_fn_int(unit: str):
        """Build a generator for an integer field, optionally followed by ``unit``.

        Only the digit tokens, the EOS tokens and (if ``unit`` is set) the
        first token of ``unit`` can be sampled.  Generation stops at EOS or
        at the unit's first token; if ``unit`` is set, the complete unit is
        then evaluated and appended.  The returned callable takes
        ``(eval_t, sample_t)`` and returns the detokenized text.
        """
        t_int = str_tokenize('0123456789')
        if len(t_int) != 10:
            # The mask below assumes every digit is a single token.
            raise ValueError(
                f'expected "0123456789" to tokenize to 10 tokens, got {len(t_int)}')
        unit_t = str_tokenize(unit) if unit else []

        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[t_int] = 0
        if unit:
            # Only the unit's first token is allowed; the rest is fed at the end.
            fn_int_mask[unit_t[0]] = 0
        logits_processor = masked_logits_processor(fn_int_mask)

        def inner(eval_t, sample_t):
            retn = []
            while True:
                token = sample_t(logits_processor)
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                retn.append(token)
                eval_t([token])
            if unit:
                # The stop token (EOS or the unit's first token) was not
                # evaluated; the complete unit is evaluated and appended.
                eval_t(unit_t)
                retn.extend(unit_t)
            return model.str_detokenize(retn)

        return inner

    def btn_status_bar_fn_set(value):
        """Build a generator that selects exactly one string from ``value``.

        The candidates' token sequences are walked as a prefix tree: each
        sampling step only allows tokens that continue at least one remaining
        candidate (plus the EOS tokens where a candidate ends), so values
        sharing leading tokens all stay reachable.  As soon as the remaining
        candidates agree on one token sequence, its not-yet-evaluated tail is
        evaluated in a single call and the value is returned.  When all
        first tokens are distinct this is one sampling step, as before.

        The returned callable takes ``(eval_t, sample_t)``.  It raises
        ``KeyError`` if a sampled token matches no candidate.
        """
        options = [(str_tokenize(_y), _y) for _y in value]

        def allowed_mask(candidates, depth):
            """Return the bias allowing every continuation at ``depth``."""
            mask = btn_status_bar_fn_mask()
            continuations = [t[depth] for t, _ in candidates if len(t) > depth]
            if continuations:
                mask[continuations] = 0
            if any(len(t) == depth for t, _ in candidates):
                # A candidate ends here: sampling EOS selects it.
                mask[chat_template.eos] = 0
            return mask

        root_processor = masked_logits_processor(allowed_mask(options, 0))

        def inner(eval_t, sample_t):
            candidates = options
            depth = 0
            logits_processor = root_processor
            while True:
                token = sample_t(logits_processor)
                ended = [c for c in candidates if len(c[0]) == depth]
                if ended and token in chat_template.eos:
                    # Every token of the chosen value is already evaluated.
                    return ended[-1][1]
                candidates = [c for c in candidates
                              if len(c[0]) > depth and c[0][depth] == token]
                if not candidates:
                    raise KeyError(token)
                if len({tuple(t) for t, _ in candidates}) == 1:
                    # Unambiguous now: feed the remaining tokens at once.
                    tokens, choice = candidates[-1]
                    eval_t(tokens[depth:])
                    return choice
                eval_t([token])
                depth += 1
                logits_processor = masked_logits_processor(
                    allowed_mask(candidates, depth))

        return inner

    def btn_status_bar_fn_str():
        """Build a generator for a free-text field that ends at a line break.

        Sampling is unconstrained.  It stops at EOS, or once the decoded text
        ends with a newline or carriage return; that final token is not
        evaluated.  The returned callable takes ``(eval_t, sample_t)`` and
        returns the decoded text with surrounding whitespace stripped.
        """
        def inner(eval_t, sample_t):
            retn = []
            text = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                retn.append(token)
                # The whole sequence is re-decoded each step (presumably so
                # characters spanning several tokens decode correctly).
                text = model.str_detokenize(retn)
                if text.endswith(('\n', '\r')):
                    break
                eval_t([token])
            return text.strip()

        return inner

    def prepare_entry(index, entry):
        """Format and tokenize one status-bar entry and attach its field generators."""
        roles = {'char': cfg['role_char'].value, 'user': cfg['role_usr'].value}
        entry['key'] = text_format(entry['key'], **roles)
        entry['key_t'] = str_tokenize(entry['key'])
        if index > 0:
            # Later entries are prefixed with the last token of im_end_nl
            # (presumably the newline) so each entry starts its own line.
            entry['key_t'] = chat_template.im_end_nl[-1:] + entry['key_t']
        entry['desc'] = text_format(entry['desc'], **roles)
        if entry['desc']:
            entry['desc_t'] = str_tokenize(entry['desc'])

        for field in entry['combine']:
            if field['prefix']:
                field['prefix_t'] = str_tokenize(field['prefix'])
            if field['type'] == 'int':
                field['fn'] = btn_status_bar_fn_int(field['unit'])
            elif field['type'] == 'set':
                field['fn'] = btn_status_bar_fn_set(field['value'])
            elif field['type'] == 'str':
                field['fn'] = btn_status_bar_fn_str()
            # NOTE(review): any other type gets no 'fn' here, yet
            # btn_status_bar calls field['fn'] for every field -- confirm
            # something else fills it in, otherwise the click fails.

    # Per-entry work lives in prepare_entry, so no loop-variable cleanup is
    # needed afterwards and an empty list or empty 'combine' is fine.
    for index, entry in enumerate(cfg['btn_status_bar_list']):
        prepare_entry(index, entry)

    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status bar, yielding ``(rows, model.venv_info)`` as it fills.

        ``rows`` is a list of ``[key, value]`` pairs; the same list object is
        updated in place between yields.  Generation runs between
        ``model.venv_create('status')`` and
        ``model.venv_remove('status', keep_last=1)``; the removal also runs
        when generation raises or the generator is closed early.
        ``_usr``, ``_char``, ``_rag`` and ``_max_tokens`` are not used here
        (presumably accepted because the inputs are ``cfg['setting']``).

        Raises:
            RuntimeError: ``cfg['session_active']`` is falsy.
        """
        # The session lock is held for the whole generation, across yields.
        with lock:
            if not cfg['session_active']:
                raise RuntimeError
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor
                )

            rows = []
            model.venv_create('status')
            try:
                eval_t(chat_template('状态'))
                for entry in cfg['btn_status_bar_list']:
                    rows.append([entry['key'], ''])
                    eval_t(entry['key_t'])
                    if entry['desc']:
                        eval_t(entry['desc_t'])
                    yield rows, model.venv_info

                    for field in entry['combine']:
                        if field['prefix']:
                            # Leading ':' is dropped while the row's value is empty.
                            if rows[-1][-1]:
                                rows[-1][-1] += field['prefix']
                            else:
                                rows[-1][-1] += field['prefix'].lstrip(':')
                            eval_t(field['prefix_t'])
                        rows[-1][-1] += field['fn'](eval_t, sample_t)
                        yield rows, model.venv_info
                eval_t(chat_template.im_end_nl)
            finally:
                # Runs on success, on errors and on early close, so the
                # 'status' venv is never left behind.
                # NOTE(review): keep_last=1 is reused from the success path;
                # confirm it is also the desired state after a failed run.
                model.venv_remove('status', keep_last=1)
            yield rows, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info]
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )
|
|