from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
import torch

from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize

logger = logging.getLogger(__name__)

console = Console()
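

# Maps Whisper's two-letter language codes to the language names used when
# asking the model to reply in the language detected by the speech-to-text stage.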
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}


class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part of the pipeline: appends the user prompt to
    the chat history, runs a Hugging Face text-generation pipeline in a background
    thread, and yields the streamed reply in sentence-sized chunks.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
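
        # Load the tokenizer and model, wrap them in a text-generation pipeline,
        # and attach a TextIteratorStreamer so generated tokens can be consumed
        # incrementally while generation runs in a separate thread.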
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }
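
        # Conversation history; when init_chat_role is given, seed it with the
        # initial prompt (e.g. a system message).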
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role
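
        # Run a couple of dummy generations so the first real request does not
        # pay one-off initialization costs.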
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Repeat the word 'home'."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        warmup_gen_kwargs = {
            "min_new_tokens": self.gen_kwargs["min_new_tokens"],
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            **self.gen_kwargs,
        }

        n_steps = 2
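
        # Run n_steps dummy generations. Each pipeline call runs in a worker
        # thread and the streamer is drained so the call completes. On CUDA the
        # run is timed with events (elapsed_time is in milliseconds, hence the
        # 1e-3 factor below).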
        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
        console.print("inferring language model...")
        console.print(prompt)
        logger.debug("inferring language model...")
        language_code = None
        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt

        self.chat.append({"role": self.user_role, "content": prompt})
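
        # Run generation in a worker thread and consume tokens from the streamer
        # here, yielding sentence-sized chunks so downstream consumers can start
        # before the full reply is generated. On MPS the whole reply is
        # accumulated, yielded once at the end, and the MPS cache is cleared.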
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()
        if self.device == "mps":
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0], language_code)
                    printable_text = new_text

        self.chat.append({"role": "assistant", "content": generated_text})
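
        # Don't forget the last chunk: flush whatever follows the final sentence
        # boundary (or the whole reply on MPS).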
        yield (printable_text, language_code)