import codecs
from ctypes import POINTER
from typing import Optional, Sequence, Generator

from llama_cpp import Llama, LogitsProcessorList, LlamaGrammar, llama_cpp, npt, np, StoppingCriteriaList

from KMP_list import kmp_search, compute_lps_array
|
|
|
|
def is_UTF8_incomplete(all_text):
    """Return how many bytes are still missing from a trailing UTF-8 character.

    Only the last three bytes are inspected (inputs shorter than that are
    left-padded with ASCII ``b'0'`` bytes, which can never be lead bytes).

    Args:
        all_text: UTF-8 encoded bytes, possibly cut in the middle of a character.

    Returns:
        0 when the text does not end inside a multi-byte sequence, otherwise the
        number of continuation bytes needed to complete the last character.
    """
    tail = all_text if len(all_text) >= 3 else b'000' + all_text
    missing = 0
    # `distance` is how many bytes, counting this one, remain until the end.
    for distance, byte in zip((3, 2, 1), tail[-3:]):
        # Lead-byte masks for 2-, 3- and 4-byte sequences.  Later (longer)
        # matches deliberately overwrite earlier ones.
        for seq_len, lead_mask in ((2, 0xC0), (3, 0xE0), (4, 0xF0)):
            if seq_len > distance and byte & lead_mask == lead_mask:
                missing = seq_len - distance
    return missing
|
|
|
|
def get_complete_UTF8(all_text):
    """Decode UTF-8 bytes, withholding a trailing incomplete character.

    Streaming detokenization can stop in the middle of a multi-byte UTF-8
    character.  The bytes of such an unfinished final character are left out
    of the result (not decoded) so the caller can retry once more bytes arrive.

    Args:
        all_text: UTF-8 encoded bytes, possibly ending mid-character.

    Returns:
        The decoded text of every complete character in *all_text*.

    Raises:
        UnicodeDecodeError: if *all_text* contains an invalid byte sequence
            anywhere other than an unfinished sequence at the very end.
    """
    # A non-final incremental decode emits everything that can be decoded and
    # buffers an unfinished trailing sequence instead of raising on it.  This
    # removes exactly the partial character's bytes for 2-, 3- and 4-byte
    # sequences alike (trimming a fixed ``3 - missing`` bytes is only right
    # for 3-byte characters).
    decoder = codecs.getincrementaldecoder("utf-8")()
    return decoder.decode(all_text, final=False)
|
|
|
|
class StreamingLLM(Llama):
    """``Llama`` with manual KV-cache control for long-running streaming chat.

    On top of ``Llama`` this class provides:

    * ``eval_t`` / ``sample_t`` / ``generate_t``: evaluation, sampling and a
      generator loop.  ``eval_t`` shifts the KV cache left
      (``kv_cache_seq_ltrim``) when the context window would overflow.
    * A stack of named "virtual environments" (venvs).  ``self.venv[0]`` counts
      tokens evaluated in the base layer and ``self.venv[i]`` (i >= 1) counts
      tokens evaluated while the layer named ``self.venv_idx_map[i - 1]`` was
      on top.  Layers can be rolled back (``venv_revision``), cut out
      (``venv_remove``) or merged into the base layer (``venv_disband``).
    * Raw llama.cpp session file save/load.

    NOTE(review): the context shift in ``eval_t`` removes tokens from the cache
    without adjusting the ``self.venv`` counters, so after a shift the counters
    may no longer sum to ``self.n_tokens``.
    """

    def __init__(self, model_path: str, **kwargs):
        """Load the model via ``Llama.__init__`` and start an empty venv stack."""
        super().__init__(model_path, **kwargs)
        self._venv_init()

    def str_detokenize(self, tokens) -> str:
        """Detokenize *tokens*, withholding a trailing partial UTF-8 character."""
        return get_complete_UTF8(self.detokenize(tokens))

    def kv_cache_seq_trim(self):
        """Remove every KV-cache cell at position ``>= self.n_tokens``.

        Relies on llama.cpp's convention that ``seq_id = -1`` matches every
        sequence and ``p1 = -1`` means "to the end".
        """
        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    def _venv_init(self):
        """Reset the venv stack to a single, empty base layer."""
        self.venv = [0]  # token count per layer; index 0 is the base layer
        self.venv_idx_map = []  # layer names; venv_idx_map[i] names venv[i + 1]

    def venv_create(self, name: str):
        """Push a new, empty layer called *name* and return *name*.

        Names need not be unique; several layers may share one name.
        """
        self.venv.append(0)
        self.venv_idx_map.append(name)
        return name

    def venv_disband(self, name_set):
        """Merge the bottom layers whose names are in *name_set* into the base.

        Starting from the lowest layer, each consecutive layer whose name is in
        *name_set* is popped and its token count added to ``self.venv[0]``.
        Merging stops at the first layer whose name is not in the set.  The KV
        cache itself is not touched.

        Returns:
            False if there are no layers or none of the names exist, otherwise
            True (even if the lowest layer was not in the set and nothing was
            merged).
        """
        if len(self.venv) <= 1:
            return False
        name_set = {x for x in name_set if x in self.venv_idx_map}
        if not name_set:
            return False
        while self.venv_idx_map:
            if self.venv_idx_map[0] in name_set:
                self.venv_idx_map.pop(0)
                tmp = self.venv.pop(1)
                self.venv[0] += tmp
            else:
                break
        return True

    def venv_revision(self, name: str):
        """Roll back to the topmost layer named *name*.

        Every layer above it is popped and its tokens are dropped from
        ``self.n_tokens`` and the KV cache.  The layer *name* itself, and its
        tokens, are kept.

        Returns:
            False if there are no layers or *name* is unknown, otherwise True.
        """
        if len(self.venv) <= 1:
            return False
        if name not in self.venv_idx_map:
            return False
        _s = 0  # total tokens owned by the popped layers
        while self.venv_idx_map:
            if self.venv_idx_map[-1] == name:
                break
            self.venv_idx_map.pop()
            _s += self.venv.pop()
        if _s:
            self.n_tokens -= min(_s, self.n_tokens)
            self.kv_cache_seq_trim()
        return True

    def venv_remove(self, name: str):
        """Delete every layer named *name* together with the tokens it owns.

        A matching top layer is truncated off the end of the cache.  A matching
        lower layer is cut out with ``kv_cache_seq_ltrim``, so the tokens above
        it slide down.

        Returns:
            False if there are no layers or *name* is unknown, otherwise True.
        """
        if len(self.venv) <= 1:
            return False
        if name not in self.venv_idx_map:
            return False
        venv_idx = self.venv_idx_map.index(name) + 1  # index into self.venv
        while self.venv_idx_map:
            self.venv_idx_map.pop(venv_idx - 1)
            if venv_idx == len(self.venv) - 1:
                # Top layer: its tokens sit at the end, so just truncate.
                self.n_tokens -= min(self.venv.pop(), self.n_tokens)
                self.kv_cache_seq_trim()
                break
            else:
                # Inner layer: it starts after all tokens of the layers below,
                # i.e. n_tokens minus this layer and everything above it.
                n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
                n_discard = self.venv.pop(venv_idx)
                self.kv_cache_seq_ltrim(n_keep, n_discard)
                try:
                    # Both lists shifted down by one; resume at the same slot.
                    venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
                except ValueError:
                    break
        return True

    def venv_pop_token(self):
        """Forget the most recently evaluated token.

        Decrements ``self.n_tokens`` and the top layer's counter, then truncates
        the KV cache to match.  Does nothing when no tokens have been evaluated.
        """
        if self.n_tokens <= 0:
            # Nothing to pop; decrementing would make n_tokens negative.
            return
        self.n_tokens -= 1
        self.venv[-1] -= 1
        self.kv_cache_seq_trim()

    @property
    def venv_info(self):
        """Debug string ``str((n_tokens, venv, venv_idx_map))``."""
        return str((self.n_tokens, self.venv, self.venv_idx_map))

    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
        """Cut tokens out of the cache right after the first *n_keep* ones.

        Cells in ``[n_keep, n_keep + n_discard)`` are removed.  Cells in
        ``[n_keep + n_discard, n_past)`` are shifted left by *n_discard*.
        ``self.input_ids`` is compacted the same way and ``self.n_tokens``
        becomes ``n_past - n_discard``.

        Args:
            n_keep: number of leading tokens that are never discarded.
            n_discard: number of tokens to remove after *n_keep*.
            n_past: end of the live region; negative means ``self.n_tokens``.
            im_start: optional token sequence used to align the cut to a
                message boundary.  If it is found starting at
                ``n_keep + n_discard``, the discard is extended up to that
                match.  Otherwise, if it is found starting at *n_keep*,
                *n_keep* moves to just past that match.

        NOTE(review): assumes ``kmp_search(text, pattern, start, end, lps)``
        returns the index of the first match in ``[start, end)`` or a negative
        value when absent — TODO confirm against KMP_list.

        NOTE(review): removal targets every sequence (-1) but the shift only
        sequence 0.
        """
        if n_past < 0:
            n_past = self.n_tokens
        if im_start is not None:
            lps = compute_lps_array(im_start)
            _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
            if _idx >= n_keep:
                n_discard = _idx - n_keep
            else:
                _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
                if _idx >= n_keep:
                    n_keep = _idx + len(im_start)
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard

    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        """Evaluate *tokens* and attribute them to the current (top) venv layer.

        If the tokens would not fit in the context window, first discard at
        least the overflow (and never fewer than *n_discard*) tokens after the
        first *n_keep* via ``kv_cache_seq_ltrim``, optionally aligned to
        *im_start*.  Tokens are then decoded in ``self.n_batch``-sized batches.
        Token ids go into ``self.input_ids``.  Logits go into ``self.scores``:
        every row when ``logits_all`` is set, otherwise only each batch's last
        row.

        Returns:
            ``self.n_tokens`` after evaluation.
        """
        if self._n_ctx < self.n_tokens + len(tokens):
            # Free at least the overflow, but never less than n_discard.
            tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i: i + self.n_batch]
            n_past = self.n_tokens
            n_tokens = len(batch)
            self._batch.set_batch(
                batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
            )
            self._ctx.decode(self._batch)
            # Save tokens.
            self.input_ids[n_past: n_past + n_tokens] = batch
            # Save logits (only the batch's last row unless logits_all).
            rows = n_tokens
            cols = self._n_vocab
            offset = (
                0 if self.context_params.logits_all else n_tokens - 1
            )
            self.scores[n_past + offset: n_past + n_tokens, :].reshape(-1)[
                :
            ] = self._ctx.get_logits()[offset * cols: rows * cols]
            self.n_tokens += n_tokens
            self.venv[-1] += n_tokens
        return self.n_tokens

    def sample_t(
            self,
            top_k: int = 40,
            top_p: float = 0.95,
            min_p: float = 0.05,
            typical_p: float = 1.0,
            temp: float = 0.80,
            repeat_penalty: float = 1.1,
            repeat_last_n: int = 64,
            frequency_penalty: float = 0.0,
            presence_penalty: float = 0.0,
            tfs_z: float = 1.0,
            mirostat_mode: int = 0,
            mirostat_eta: float = 0.1,
            mirostat_tau: float = 5.0,
            penalize_nl: bool = True,
            logits_processor: Optional[LogitsProcessorList] = None,
            grammar: Optional[LlamaGrammar] = None,
    ):
        """Sample one token id from the logits of the last evaluated position.

        Pipeline:
          1. Optionally run *logits_processor*.
          2. Apply repetition/frequency/presence penalties over the last
             *repeat_last_n* tokens.
          3. Optionally constrain the candidates with *grammar*.
          4. Sample with the selected strategy:
             - ``temp < 0``: take the most probable token after softmax;
             - ``temp == 0``: greedy;
             - mirostat v1/v2;
             - otherwise top-k / tail-free / typical / top-p / min-p /
               temperature.

        *repeat_last_n* sets the penalty window: a negative value means the
        whole context (``n_ctx``), and 0 passes an empty window.

        Returns:
            The sampled token id.  When *grammar* is given it has already
            accepted the token.
        """
        n_vocab = self._n_vocab
        n_ctx = self._n_ctx
        top_k = n_vocab if top_k <= 0 else top_k
        # Penalty window size.  A bare ``self._input_ids[-repeat_last_n:]``
        # would select the whole history for 0 (because -0 == 0), so handle
        # the non-positive cases explicitly.
        last_n_tokens_size = n_ctx if repeat_last_n < 0 else repeat_last_n
        recent_tokens = (
            self._input_ids[-last_n_tokens_size:].tolist()
            if last_n_tokens_size > 0
            else []
        )
        # Left-pad with token 0 when fewer tokens than the window exist.
        last_n_tokens_data = [llama_cpp.llama_token(0)] * (
            last_n_tokens_size - len(recent_tokens)
        ) + recent_tokens
        last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
            *last_n_tokens_data
        )
        logits: npt.NDArray[np.single] = self.scores[self.n_tokens - 1: self.n_tokens, :].ravel()

        if logits_processor is not None:
            logits[:] = logits_processor(self._input_ids, logits)

        self._candidates.copy_logits(logits)
        self._ctx.sample_repetition_penalties(
            candidates=self._candidates,
            last_tokens_data=last_n_tokens_data_c,
            penalty_last_n=last_n_tokens_size,
            penalty_repeat=repeat_penalty,
            penalty_freq=frequency_penalty,
            penalty_present=presence_penalty,
        )
        if not penalize_nl:
            # Restore the newline token's pre-penalty logit.
            nl_logit = logits[self._token_nl]
            self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
                nl_logit
            )

        if grammar is not None:
            self._ctx.sample_grammar(
                candidates=self._candidates,
                grammar=grammar,
            )

        # NOTE(review): mu is re-seeded to 2 * tau on every call, so mirostat's
        # running estimate may not persist across tokens — confirm against the
        # _ctx wrapper.
        if temp < 0.0:
            self._ctx.sample_softmax(candidates=self._candidates)
            id_ = self._candidates.candidates.data[0].id
        elif temp == 0.0:
            id_ = self._ctx.sample_token_greedy(candidates=self._candidates)
        elif mirostat_mode == 1:
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token_mirostat(
                candidates=self._candidates,
                tau=mirostat_tau,
                eta=mirostat_eta,
                mu=2.0 * mirostat_tau,
                m=100,
            )
        elif mirostat_mode == 2:
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token_mirostat_v2(
                candidates=self._candidates,
                tau=mirostat_tau,
                eta=mirostat_eta,
                mu=2.0 * mirostat_tau,
            )
        else:
            self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
            self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
            self._ctx.sample_typical(
                candidates=self._candidates, p=typical_p, min_keep=1
            )
            self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token(candidates=self._candidates)
        if grammar is not None:
            # Advance the grammar for every sampling mode so its state keeps
            # matching the generated text.
            self._ctx.grammar_accept_token(grammar=grammar, token=id_)
        return id_

    def generate_t(
            self,
            tokens: Sequence[int],
            n_keep,
            n_discard: int = 256,
            im_start=None,
            top_k: int = 40,
            top_p: float = 0.95,
            min_p: float = 0.05,
            typical_p: float = 1.0,
            temp: float = 0.80,
            repeat_penalty: float = 1.1,
            repeat_last_n: int = 64,
            frequency_penalty: float = 0.0,
            presence_penalty: float = 0.0,
            tfs_z: float = 1.0,
            mirostat_mode: int = 0,
            mirostat_tau: float = 5.0,
            mirostat_eta: float = 0.1,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            grammar: Optional[LlamaGrammar] = None,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Yield sampled token ids indefinitely.

        The first step evaluates *tokens*.  Each later step evaluates the
        previously yielded token, followed by any extra tokens the caller
        passes in with ``generator.send(extra_tokens)``.  Evaluation goes
        through ``eval_t`` (context shifting with *n_keep* / *n_discard* /
        *im_start*) and sampling through ``sample_t``.  The generator returns
        when *stopping_criteria* reports true.
        """
        typical_p = float(typical_p)
        frequency_penalty = float(frequency_penalty)
        presence_penalty = float(presence_penalty)
        tfs_z = float(tfs_z)
        mirostat_tau = float(mirostat_tau)
        while True:
            self.eval_t(tokens, n_keep, n_discard, im_start=im_start)
            token = self.sample_t(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                repeat_last_n=repeat_last_n,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
            )
            if stopping_criteria is not None and stopping_criteria(
                    self._input_ids, self._scores[-1, :]
            ):
                return
            tokens_or_none = yield token
            tokens = [token]
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

    def load_session(self, filepath: str):
        """Restore the context state and token history from a session file.

        Reads up to ``self.n_ctx()`` tokens via ``llama_load_session_file`` and
        copies them into ``self.input_ids``.  The venv stack is reset, with all
        restored tokens attributed to the base layer.

        Returns:
            Whatever ``llama_cpp.llama_load_session_file`` returned.
        """
        n_tokens = POINTER(llama_cpp.c_size_t)(llama_cpp.c_size_t(0))
        tokens = (llama_cpp.llama_token * self.n_ctx())()
        retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
                                                 filepath.encode('utf-8'),
                                                 tokens,
                                                 self.n_ctx(),
                                                 n_tokens)
        self.n_tokens = n_tokens.contents.value
        self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
        self._venv_init()
        # Keep the base-layer counter in step with the restored tokens.
        self.venv[0] = self.n_tokens
        return retn

    def save_session(self, filepath: str):
        """Write the context state plus the first ``self.n_tokens`` tokens to *filepath*.

        Returns:
            Whatever ``llama_cpp.llama_save_session_file`` returned.
        """
        tokens = self._input_ids.tolist()
        tokens = (llama_cpp.llama_token * len(tokens))(*tokens)
        return llama_cpp.llama_save_session_file(self._ctx.ctx, filepath.encode('utf-8'), tokens, self.n_tokens)
|
|