"""Stream chat replies from Mistral-7B-Instruct via the Hugging Face Inference API.

Builds an ``[INST]``-style prompt from the chat history and streams tokens
through the ``text_generation`` client.
"""

import os
from typing import Iterator

from text_generation import Client

_MODEL_ID = 'mistralai/Mistral-7B-Instruct-v0.1'
API_URL = "https://api-inference.huggingface.co/models/" + _MODEL_ID
# Read-access token; may be unset (None).
HF_TOKEN = os.environ.get("HF_READ_TOKEN", None)

# Only attach an Authorization header when a token is configured; otherwise
# the request would carry the literal header value "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

# Text markers that end generation. They must be non-empty: the empty string
# is a substring of every token, which would stop the stream immediately.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"


def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    """Build a single ``[INST] ... [/INST]`` prompt string for the conversation.

    The system prompt is wrapped in ``<<SYS>>`` / ``<</SYS>>`` inside the first
    ``[INST]`` turn (Llama-2 layout).
    NOTE(review): Mistral-Instruct's own template has no system block;
    confirm the model honours this layout.

    Args:
        message: The new user message to be answered.
        chat_history: Previous ``(user_input, response)`` turns, in order.
        system_prompt: Text placed in the system block of the first turn.

    Returns:
        The full prompt, ending in ``[/INST]`` so the model continues with
        its reply.
    """
    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped; every later one is.
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} [INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    Each yielded value is the *accumulated* output so far, not just the
    newest token. Streaming stops early when a token's text contains
    ``EOS_STRING`` or ``EOT_STRING``; that token is not included.

    Args:
        message: The new user message.
        chat_history: Previous ``(user_input, response)`` turns.
        system_prompt: System instructions for the first turn.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.

    Yields:
        The generated text accumulated so far.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    # Ignore empty markers: "" is "in" every string and would end the stream
    # on the very first token.
    end_strings = [s for s in (EOS_STRING, EOT_STRING) if s]
    stream = client.generate_stream(prompt, **generate_kwargs)
    output = ""
    for response in stream:
        token_text = response.token.text
        if any(end in token_text for end in end_strings):
            return output
        output += token_text
        yield output
    return output