Limour commited on
Commit
b9cb0bd
1 Parent(s): 0bcc3f8

Upload 2 files

Browse files
Files changed (2) hide show
  1. mods/btn_com.py +1 -1
  2. mods/btn_status_bar.py +211 -1
mods/btn_com.py CHANGED
@@ -84,7 +84,7 @@ def init(cfg):
84
  # ========== 查看末尾的换行符 ==========
85
  print('history', repr(history))
86
  # ========== 给 kv_cache 加上输出结束符 ==========
87
- model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
88
  t_bot.extend(chat_template.im_end_nl)
89
 
90
  cfg['btn_com'] = btn_com
 
84
  # ========== 查看末尾的换行符 ==========
85
  print('history', repr(history))
86
  # ========== 给 kv_cache 加上输出结束符 ==========
87
+ model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard, chat_template.im_start_token)
88
  t_bot.extend(chat_template.im_end_nl)
89
 
90
  cfg['btn_com'] = btn_com
mods/btn_status_bar.py CHANGED
@@ -1,4 +1,214 @@
 
 
 
1
  def init(cfg):
2
  chat_template = cfg['chat_template']
3
  model = cfg['model']
4
- lock = cfg['session_lock']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
  def init(cfg):
5
  chat_template = cfg['chat_template']
6
  model = cfg['model']
7
+ s_info = cfg['s_info']
8
+ lock = cfg['session_lock']
9
+
10
+ # ========== 预处理 key、desc ==========
11
+ def str_tokenize(s):
12
+ s = model.tokenize((chat_template.nl + s).encode('utf-8'), add_bos=False, special=False)
13
+ if s[0] in chat_template.onenl:
14
+ return s[1:]
15
+ else:
16
+ return s
17
+
18
+ text_format = cfg['text_format']
19
+ for x in cfg['btn_status_bar_list']:
20
+ x['key'] = text_format(x['key'],
21
+ char=cfg['role_char'].value,
22
+ user=cfg['role_usr'].value)
23
+ x['key_t'] = str_tokenize(x['key'])
24
+ x['desc'] = text_format(x['desc'],
25
+ char=cfg['role_char'].value,
26
+ user=cfg['role_usr'].value)
27
+ if x['desc']:
28
+ x['desc_t'] = str_tokenize(x['desc'])
29
+
30
+ # ========== 预处理 构造函数 mask ==========
31
+ def btn_status_bar_fn_mask():
32
+ _shape1d = model.scores.shape[-1]
33
+ mask = np.full((_shape1d,), -np.inf, dtype=np.single)
34
+ return mask
35
+
36
+ # ========== 预处理 构造函数 数字 ==========
37
+ def btn_status_bar_fn_int(unit: str):
38
+ t_int = str_tokenize('0123456789')
39
+ assert len(t_int) == 10
40
+ fn_int_mask = btn_status_bar_fn_mask()
41
+ fn_int_mask[chat_template.eos] = 0
42
+ fn_int_mask[t_int] = 0
43
+ if unit:
44
+ unit_t = str_tokenize(unit)
45
+ fn_int_mask[unit_t[0]] = 0
46
+
47
+ def logits_processor(_input_ids, logits):
48
+ return logits + fn_int_mask
49
+
50
+ def inner(eval_t, sample_t):
51
+ retn = []
52
+ while True:
53
+ token = sample_t(logits_processor)
54
+ # ========== 不是数字就结束 ==========
55
+ if token in chat_template.eos:
56
+ break
57
+ if unit and token == unit_t[0]:
58
+ break
59
+ # ========== 是数字就继续 ==========
60
+ retn.append(token)
61
+ eval_t([token])
62
+
63
+ if unit:
64
+ eval_t(unit_t) # 添加单位
65
+ retn.extend(unit_t)
66
+
67
+ return model.str_detokenize(retn)
68
+
69
+ return inner
70
+
71
+ # ========== 预处理 构造函数 集合 ==========
72
+ def btn_status_bar_fn_set(value):
73
+ value_t = {_x[0][0]: _x for _x in ((str_tokenize(_y), _y) for _y in value)}
74
+ fn_set_mask = btn_status_bar_fn_mask()
75
+ fn_set_mask[list(value_t.keys())] = 0
76
+
77
+ def logits_processor(_input_ids, logits):
78
+ return logits + fn_set_mask
79
+
80
+ def inner(eval_t, sample_t):
81
+ token = sample_t(logits_processor)
82
+ eval_t(value_t[token][0])
83
+ return value_t[token][1]
84
+
85
+ return inner
86
+
87
+ # ========== 预处理 构造函数 字符串 ==========
88
+ def btn_status_bar_fn_str():
89
+ def inner(eval_t, sample_t):
90
+ retn = []
91
+ tmp = ''
92
+ while True:
93
+ token = sample_t(None)
94
+ if token in chat_template.eos:
95
+ break
96
+ retn.append(token)
97
+ tmp = model.str_detokenize(retn)
98
+ if tmp.endswith('\n') or tmp.endswith('\r'):
99
+ break
100
+ # ========== 继续 ==========
101
+ eval_t([token])
102
+ return tmp.strip()
103
+
104
+ return inner
105
+
106
+ # ========== 预处理 value ==========
107
+ for x in cfg['btn_status_bar_list']:
108
+ for y in x['combine']:
109
+ if y['prefix']:
110
+ y['prefix_t'] = str_tokenize(y['prefix'])
111
+
112
+ if y['type'] == 'int':
113
+ y['fn'] = btn_status_bar_fn_int(y['unit'])
114
+ elif y['type'] == 'set':
115
+ y['fn'] = btn_status_bar_fn_set(y['value'])
116
+ elif y['type'] == 'str':
117
+ y['fn'] = btn_status_bar_fn_str()
118
+ else:
119
+ pass
120
+
121
+ # ========== 添加分隔标记 ==========
122
+ for i, x in enumerate(cfg['btn_status_bar_list']):
123
+ if i == 0: # 跳过第一个
124
+ continue
125
+ x['key_t'] = chat_template.im_end_nl[-1:] + x['key_t']
126
+
127
+ del x # 避免干扰
128
+ del y
129
+
130
+ # print(cfg['btn_status_bar_list'])
131
+
132
+ # ========== 输出状态栏 ==========
133
+ def btn_status_bar(_n_keep, _n_discard,
134
+ _temperature, _repeat_penalty, _frequency_penalty,
135
+ _presence_penalty, _repeat_last_n, _top_k,
136
+ _top_p, _min_p, _typical_p,
137
+ _tfs_z, _mirostat_mode, _mirostat_eta,
138
+ _mirostat_tau, _usr, _char,
139
+ _rag, _max_tokens):
140
+ with lock:
141
+ if not cfg['session_active']:
142
+ raise RuntimeError
143
+ if cfg['btn_stop_status']:
144
+ yield [], model.venv_info
145
+ return
146
+
147
+ # ========== 临时的eval和sample ==========
148
+ def eval_t(tokens):
149
+ return model.eval_t(
150
+ tokens=tokens,
151
+ n_keep=_n_keep,
152
+ n_discard=_n_discard,
153
+ im_start=chat_template.im_start_token
154
+ )
155
+
156
+ def sample_t(logits_processor):
157
+ return model.sample_t(
158
+ top_k=_top_k,
159
+ top_p=_top_p,
160
+ min_p=_min_p,
161
+ typical_p=_typical_p,
162
+ temp=_temperature,
163
+ repeat_penalty=_repeat_penalty,
164
+ repeat_last_n=_repeat_last_n,
165
+ frequency_penalty=_frequency_penalty,
166
+ presence_penalty=_presence_penalty,
167
+ tfs_z=_tfs_z,
168
+ mirostat_mode=_mirostat_mode,
169
+ mirostat_tau=_mirostat_tau,
170
+ mirostat_eta=_mirostat_eta,
171
+ logits_processor=logits_processor
172
+ )
173
+
174
+ # ========== 初始化输出模版 ==========
175
+ model.venv_create('status') # 创建隔离环境
176
+ eval_t(chat_template('状态')) # 开始标记
177
+ # ========== 流式输出 ==========
178
+ df = [] # 清空
179
+ for _x in cfg['btn_status_bar_list']:
180
+ # ========== 属性 ==========
181
+ df.append([_x['key'], ''])
182
+ eval_t(_x['key_t'])
183
+ if _x['desc']:
184
+ eval_t(_x['desc_t'])
185
+ yield df, model.venv_info
186
+ # ========== 值 ==========
187
+ for _y in _x['combine']:
188
+ if _y['prefix']:
189
+ if df[-1][-1]:
190
+ df[-1][-1] += _y['prefix']
191
+ else:
192
+ df[-1][-1] += _y['prefix'].lstrip(':')
193
+ eval_t(_y['prefix_t'])
194
+ df[-1][-1] += _y['fn'](eval_t, sample_t)
195
+ yield df, model.venv_info
196
+ eval_t(chat_template.im_end_nl) # 结束标记
197
+ # ========== 清理上一次生成的状态栏 ==========
198
+ model.venv_remove('status', keep_last=1)
199
+ yield df, model.venv_info
200
+
201
+ cfg['btn_status_bar_fn'] = {
202
+ 'fn': btn_status_bar,
203
+ 'inputs': cfg['setting'],
204
+ 'outputs': [cfg['status_bar'], s_info]
205
+ }
206
+ cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])
207
+
208
+ cfg['btn_status_bar'].click(
209
+ **cfg['btn_start']
210
+ ).success(
211
+ **cfg['btn_status_bar_fn']
212
+ ).success(
213
+ **cfg['btn_finish']
214
+ )