import logging

import torch
from mlx_lm import load, stream_generate, generate
from rich.console import Console

from LLM.chat import Chat
from baseHandler import BaseHandler

logger = logging.getLogger(__name__)

console = Console()
|
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}
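# The keys are Whisper language codes as provided by the transcription stage;
# the values are the language names that process() splices into the prompt,
# e.g. a detected "fr" produces the prefix
# "Please reply to my message in french. ".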
|
|
|
class MLXLanguageModelHandler(BaseHandler):
    """
    Handles the language-model stage of the pipeline with mlx-lm on Apple
    Silicon: setup() loads the model and warms it up, then process() turns
    each transcribed user message into a stream of
    (text chunk, language code) tuples for the next stage. A standalone
    usage sketch sits at the end of this module.
    """
|
    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="mps",
        torch_dtype="float16",
        gen_kwargs=None,
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        # device and torch_dtype are accepted for interface compatibility but
        # unused here: mlx-lm manages placement and precision on its own.
        self.model_name = model_name
        self.model, self.tokenizer = load(self.model_name)
        # gen_kwargs must provide "max_new_tokens"; warmup() and process()
        # read it. A None default avoids a shared mutable default argument.
        self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()
|
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]

        # A couple of throwaway generations so the first real request does
        # not pay the one-off startup latency.
        n_steps = 2

        for _ in range(n_steps):
            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
            generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=self.gen_kwargs["max_new_tokens"],
                verbose=False,
            )
|
    def process(self, prompt):
        logger.debug("inferring with the language model...")
        language_code = None

        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            prompt = (
                "Please reply to my message in "
                f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. {prompt}"
            )

        self.chat.append({"role": self.user_role, "content": prompt})

        # Gemma chat templates do not accept a system role, so drop it.
        if "gemma" in self.model_name.lower():
            chat_messages = [
                msg for msg in self.chat.to_list() if msg["role"] != "system"
            ]
        else:
            chat_messages = self.chat.to_list()

        prompt = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        output = ""
        curr_output = ""
        # In the mlx-lm version this file targets, stream_generate yields
        # plain text chunks.
        for t in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=self.gen_kwargs["max_new_tokens"],
        ):
            output += t
            curr_output += t
            # Flush at sentence boundaries ("<|end|>" is Phi-3's end-of-turn
            # marker) so downstream consumers such as TTS can start early.
            if curr_output.endswith((".", "?", "!", "<|end|>")):
                yield (curr_output.replace("<|end|>", ""), language_code)
                curr_output = ""
        generated_text = output.replace("<|end|>", "")
        torch.mps.empty_cache()

        self.chat.append({"role": "assistant", "content": generated_text})
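

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the pipeline):
    # it bypasses BaseHandler.__init__ via __new__ because the handler is
    # normally constructed and driven by the pipeline's queue/event plumbing,
    # which is not reproduced here.
    handler = MLXLanguageModelHandler.__new__(MLXLanguageModelHandler)
    handler.setup(gen_kwargs={"max_new_tokens": 128})
    for chunk, _language_code in handler.process("Name three colors."):
        console.print(chunk, end="")
    console.print()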