class ChatState:
    """
    Manages the conversation history for a turn-based chatbot.

    Intended to follow the turn-based conversation guidelines for the Gemma
    family of models documented at https://ai.google.dev/gemma/docs/formatting

    NOTE(review): the turn markers below are plain "Instruction:" /
    "Response:" headers, not the ``<start_of_turn>`` / ``<end_of_turn>``
    control tokens described in that guide. Confirm which template the model
    in use was tuned on before changing them.
    """

    # Prefix written before every user turn.
    __START_TURN_USER__ = "Instruction:\n"
    # Prefix written before every model turn; also appended to the prompt to
    # cue the model to produce its reply.
    __START_TURN_MODEL__ = "\n\nResponse:\n"
    # Suffix written after every user turn. Currently empty, so the separation
    # between a user turn and the following model turn comes entirely from the
    # leading newlines of __START_TURN_MODEL__.
    __END_TURN__ = ""

    def __init__(self, model, system=""):
        """
        Initializes the chat state.

        Args:
            model: The language model used to generate responses. Must expose
                ``generate(prompt, max_length=...)`` returning a string
                (presumably the prompt followed by the completion).
            system: (Optional) System instructions or bot description that is
                prepended to every prompt. An empty string or None disables it.
        """
        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """Appends a user message wrapped in the user start/end turn markers."""
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        """
        Appends a model response prefixed with the model start-turn marker.

        A trailing newline is appended so that the next user turn starts on
        its own line.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + "\n")

    def get_history(self):
        """Returns the entire chat history as a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """
        Builds the prompt for the language model.

        The prompt is the full history followed by the model start-turn
        marker (cueing the model to reply). When a system text is set, it is
        placed first, followed by a newline.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if self.system:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message, max_length=4096):
        """
        Records a user message, queries the model, and records its reply.

        Args:
            message: The user's message.
            max_length: Forwarded to ``model.generate``. Defaults to 4096,
                the value that was previously hard-coded.

        Returns:
            The model's reply, with any echoed prompt removed from its start.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=max_length)
        # Extract only the new response. The model may echo the prompt before
        # its completion; strip it only as a leading prefix, because
        # str.replace would also delete any later copy of the prompt text
        # that happens to appear inside the reply itself.
        if response.startswith(prompt):
            result = response[len(prompt):]
        else:
            result = response
        self.add_to_history_as_model(result)
        return result