import gc
import logging
import os
from typing import List

import numpy as np
import onnxruntime
import torch
from sentencepiece import SentencePieceProcessor

from .base_interface import BaseLLMInterface
from ChatApp.app_modules.utils import (
    is_stop_word_or_prefix,
    convert_to_markdown,
    shared_state,
)


class Tokenizer:
    def __init__(self, model_path: str):
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()

        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)
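
# Illustrative use of the SentencePiece wrapper above (the "tokenizer.model"
# path is a placeholder for wherever the LLaMA tokenizer file actually lives):
#
#   tokenizer = Tokenizer(model_path="tokenizer.model")
#   ids = tokenizer.encode("Hello there", bos=True, eos=False)
#   text = tokenizer.decode(ids)
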

class LlamaOnnxInterface(BaseLLMInterface):
    def __init__(self, onnx_file="", embedding_file="", tokenizer_path=""):
        super().__init__()

        self.onnx_file = onnx_file
        self.embedding_file = embedding_file
        self.tokenizer_path = tokenizer_path

        self.total_count = 0
    def initialize(self):
        # Create the ONNX inference session with DirectML, CUDA, and CPU as
        # candidate execution providers (the first one available is used).
        logging.info(f"Creating ONNX session for [{self.onnx_file}]")
        options = onnxruntime.SessionOptions()
        self.llm_session = onnxruntime.InferenceSession(
            self.onnx_file,
            sess_options=options,
            providers=[
                "DmlExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )

        # Detect the floating-point precision the exported model expects.
        data_type_str = self.llm_session.get_inputs()[0].type
        if data_type_str == "tensor(float16)":
            self.data_type = np.float16
        elif data_type_str == "tensor(float32)":
            self.data_type = np.float32
        else:
            raise Exception(f"Unknown data type {data_type_str}")

        logging.info(f"Detected Data Type [{self.data_type}]")

        # Read the model dimensions off the session's input metadata.
        for inputs_meta in self.llm_session._inputs_meta:
            if inputs_meta.name == "x":
                x_shape = inputs_meta.shape
            elif inputs_meta.name == "attn_mask":
                attn_mask_shape = inputs_meta.shape
            elif inputs_meta.name == "k_cache":
                k_cache_shape = inputs_meta.shape

        self.hidden_size = x_shape[2]
        self.max_seq_len = attn_mask_shape[1]
        self.n_layers = k_cache_shape[1]
        self.n_heads = k_cache_shape[3]

        # Initialize the tokenizer.
        self.tokenizer = Tokenizer(model_path=self.tokenizer_path)

        # Create the embedding layer and load its weights from disk.
        logging.info(
            f"Creating the Embedding Layer. Size [{self.tokenizer.n_words}, {self.hidden_size}]"
        )
        self.embeddingLayer = torch.nn.Embedding(
            self.tokenizer.n_words, self.hidden_size
        )

        d = torch.load(self.embedding_file)
        self.embeddingLayer.load_state_dict(d)
        self.embeddingLayer.eval()

        # Build the causal attention mask: a large negative value above the
        # diagonal so future positions are masked out.
        self.attn_mask = -10000.0 * torch.triu(
            torch.ones(attn_mask_shape), diagonal=1
        ).cpu().detach().numpy().astype(self.data_type)

        # Pre-allocate the key/value caches used during generation.
        self.head_dim = int(self.hidden_size / self.n_heads)
        self.k_cache = np.zeros(
            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
            dtype=self.data_type,
        )
        self.v_cache = np.zeros(
            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
            dtype=self.data_type,
        )
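    # The ONNX model is driven step by step in greedy_search below. Each call
    # receives the current token embeddings ("x"), the causal mask
    # ("attn_mask"), the key/value caches truncated to the current position
    # ("k_cache"/"v_cache"), and the position counter ("pos"), and returns the
    # next-token logits plus the new cache entries.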
    def shutdown(self):
        pass
    def generate_prompt_with_history(self, text, history, tokenizer, max_length=2048):
        prompt = (
            "[|Human|]Hey there I am a human that would like to have "
            "a conversation with you.\n"
            "[|AI|]Sure, I am happy to answer most questions\n"
            "[|Human|]Great, I insist that we take turns.\n"
            "[|AI|]I agree, we should take turns.\n"
            "[|Human|]Great, can we also keep answers short\n"
            "[|AI|]Yes, short answers are usually best"
        )

        history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
        history.append("\n[|Human|]{}\n[|AI|]".format(text))
        history_text = ""
        flag = False
        # Walk the history from newest to oldest, keeping as many turns as
        # still fit within max_length tokens alongside the fixed prompt.
        for x in history[::-1]:
            if (
                len(
                    self.tokenizer.encode(
                        prompt + history_text + x, bos=True, eos=False
                    )
                )
                <= max_length
            ):
                history_text = x + history_text
                flag = True
            else:
                break
        if flag:
            return prompt + history_text, torch.tensor(
                self.tokenizer.encode(prompt + history_text, bos=True, eos=False)
            ).unsqueeze(0)
        else:
            return None
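    # For example, with history = [["Hi", "Hello!"]] and text = "How are you?",
    # the suffix appended to the fixed prompt above is:
    #   "\n[|Human|]Hi\n[|AI|]Hello!\n[|Human|]How are you?\n[|AI|]"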
    def sample_logits(
        self,
        logits: np.ndarray,
        sampling_method: str = "greedy",
        sampling_value: float = None,
        temperature: float = 1.0,
    ) -> np.ndarray:
        if temperature == 0 or sampling_method == "greedy":
            next_token = np.argmax(logits, axis=-1).astype(np.int64)
        elif sampling_method == "top_k" or sampling_method == "top_p":
            assert sampling_value is not None

            # Apply the temperature and turn the logits into probabilities.
            logits = logits.astype(np.float32)
            logits /= temperature
            probs = np.exp(logits) / np.sum(np.exp(logits))

            # Sort the probabilities in descending order.
            sorted_probs = np.sort(probs)[:, ::-1]
            sorted_indices = np.argsort(probs)[:, ::-1]

            # Decide how many of the most likely tokens to keep: a fixed count
            # for top-k, or enough to exceed cumulative probability p for top-p.
            if sampling_method == "top_k":
                index_of_interest = int(sampling_value)
            elif sampling_method == "top_p":
                p = sampling_value
                cumulative_probs = np.cumsum(sorted_probs, axis=-1)

                for index_of_interest, cumulative_prob in enumerate(
                    cumulative_probs[0]
                ):
                    if cumulative_prob > p:
                        break

            probs_of_interest = sorted_probs[:, : index_of_interest + 1]
            indices_of_interest = sorted_indices[:, : index_of_interest + 1]

            # Renormalize the truncated distribution and sample from it.
            probs_of_interest /= np.sum(probs_of_interest)
            next_token = np.array(
                [np.random.choice(indices_of_interest[0], p=probs_of_interest[0])]
            )
        else:
            raise Exception(f"Unknown sampling method {sampling_method}")

        return next_token
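    # Worked top-p example: with sorted probabilities [0.5, 0.3, 0.15, 0.05]
    # and p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the loop
    # stops at the third entry (0.95 > 0.9), so the next token is sampled from
    # the three most likely candidates after renormalization.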
    def greedy_search(
        self,
        input_ids,
        model,
        tokenizer,
        stop_words: list,
        max_length: int,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 25,
    ):
        generated_tokens = []
        pos = np.array(0)

        # Embed the whole prompt for the first step; subsequent steps only
        # embed the single token produced in the previous iteration.
        x = (
            self.embeddingLayer(torch.tensor(input_ids))
            .detach()
            .cpu()
            .numpy()
            .astype(self.data_type)
        )

        for i in range(max_length):
            results = self.llm_session.run(
                None,
                {
                    "x": x,
                    "attn_mask": self.attn_mask,
                    "k_cache": self.k_cache[:, :, :pos],
                    "v_cache": self.v_cache[:, :, :pos],
                    "pos": pos.astype(np.int64),
                },
            )
            logits, k_out, v_out = results[:3]

            next_token = self.sample_logits(logits, "top_p", top_p, temperature)
            next_token = next_token.reshape(1, -1)

            # Stop as soon as the model emits the end-of-sequence token.
            if next_token[0] == tokenizer.eos_id:
                del logits
                gc.collect()
                return

            input_ids = torch.cat((input_ids, torch.tensor(next_token)), dim=-1)

            generated_tokens.append(next_token[0].item())
            text = tokenizer.decode(generated_tokens)

            # Append the new key/value entries to the caches and advance pos.
            seq_len = x.shape[1]
            self.k_cache[:, :, pos : pos + seq_len] = k_out
            self.v_cache[:, :, pos : pos + seq_len] = v_out
            pos = np.array(int(pos) + seq_len)

            x = (
                self.embeddingLayer(torch.tensor(next_token))
                .unsqueeze(0)
                .reshape([1, 1, self.hidden_size])
                .cpu()
                .detach()
                .numpy()
                .astype(self.data_type)
            )

            yield text

            # Also stop once any of the stop words appears in the decoded text.
            if any(stop_word in text for stop_word in stop_words):
                del logits
                gc.collect()
                return
    def predict(
        self,
        text,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ):
        if text == "":
            yield chatbot, history, "Empty context."
            return
        try:
            self.llm_session
        except (AttributeError, ValueError, RuntimeError, TypeError):
            # initialize() was never called (or failed), so there is no session.
            yield [[text, "No Model Found"]], [], "No Model Found"
            return

        inputs = self.generate_prompt_with_history(
            text, history, self.tokenizer, max_length=max_context_length_tokens
        )

        if inputs is None:
            yield chatbot, history, "Input too long."
            return
        else:
            prompt, inputs = inputs

        input_ids = inputs[:, -max_context_length_tokens:]

        self.total_count += 1
        print(self.total_count)

        # Reset the key/value caches for this request.
        self.head_dim = int(self.hidden_size / self.n_heads)
        self.k_cache = np.zeros(
            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
            dtype=self.data_type,
        )
        self.v_cache = np.zeros(
            [1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
            dtype=self.data_type,
        )

        x = input_ids

        for x in self.greedy_search(
            input_ids,
            self.llm_session,
            self.tokenizer,
            stop_words=["[|Human|]", "[|AI|]"],
            max_length=max_length_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                # Trim anything generated past the next turn marker.
                if "[|Human|]" in x:
                    x = x[: x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[: x.index("[|AI|]")].strip()
                x = x.strip()
                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
                    [text, convert_to_markdown(x)]
                ], history + [[text, x]]
                yield a, b, "Generating..."
                if shared_state.interrupted:
                    shared_state.recover()
                    try:
                        yield a, b, "Stop: Success"
                        return
                    except Exception as e:
                        print(type(e).__name__, e)
                        pass

        del input_ids
        gc.collect()
        torch.cuda.empty_cache()

        try:
            yield a, b, "Generate: Success"
        except Exception as e:
            print(type(e).__name__, e)
            pass

        return
    def retry(
        self,
        text,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ):
        logging.info("Retry...")
        if len(history) == 0:
            yield chatbot, history, "Empty context"
            return
        chatbot.pop()
        inputs = history.pop()[0]
        for x in self.predict(
            inputs,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ):
            yield x
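
# Minimal usage sketch (paths below are placeholders; they must point at the
# exported LLaMA ONNX model, the saved embedding state dict, and the
# SentencePiece tokenizer file for this interface to work):
#
#   interface = LlamaOnnxInterface(
#       onnx_file="llama.onnx",
#       embedding_file="embedding.pt",
#       tokenizer_path="tokenizer.model",
#   )
#   interface.initialize()
#   for chatbot, history, status in interface.predict(
#       "Hello!",
#       chatbot=[],
#       history=[],
#       top_p=0.9,
#       temperature=0.7,
#       max_length_tokens=256,
#       max_context_length_tokens=2048,
#   ):
#       print(status)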