import logging

from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)


class Chat:
    """Simple multi-turn chat wrapper around a Hugging Face causal LM."""

    def __init__(
        self,
        path="mathewhe/Llama-3.1-8B-Chat",
        device="cuda",
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)
        self.messages = list()
        self.device = device
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        r"""Reset the chat message history."""
        self.messages = list()

    def _inference(self, messages):
        # Render the message list with the model's chat template;
        # add_generation_prompt appends the assistant header so the model
        # generates a reply instead of continuing the user turn.
        chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(
                chat, return_tensors="pt", add_special_tokens=False
            ).items()
        }
        input_length = inputs["input_ids"].shape[1]
        output = self.model.generate(**inputs, **self.gen_kwargs)
        # Decode only the newly generated tokens, skipping the prompt.
        response = self.tokenizer.decode(
            output[0][input_length:],
            skip_special_tokens=True,
        )
        return response

    def message(self, message):
        r"""Add the message to the chat history and return a response."""
        self.messages.append({"role": "user", "content": message})
        # TODO: cache the model's internal state (past key values) between
        # turns instead of re-encoding the full history on every call.
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""Run an interactive CLI chat loop (with history)."""
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        # An empty line ends the session.
        while message:
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""Return a response to a single instruction (without history)."""
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response


if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()
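

# A minimal usage sketch (assumes the checkpoint above is downloadable and
# that a CUDA device is available; pass device="cpu" otherwise):
#
#     chat = Chat()
#     print(chat.instruct("Summarize the plot of Hamlet in one sentence."))
#     print(chat.message("Hi, my name is Sam."))
#     print(chat.message("What is my name?"))  # history is retained
#     chat.reset()                             # clears the history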