import os
from typing import Any, Dict, Generator, List, Union

import gradio as gr
import openai
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

HF_TOKEN = os.getenv("HF_TOKEN")

# Display name -> Hugging Face Hub repo id.
hf_models = {
    "mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "llama 3": "meta-llama/Meta-Llama-3-70B-Instruct",
}
# OpenAI model ids; passed unchanged as ``model`` to the chat completions API.
openai_models = {"gpt-4o", "gpt-3.5-turbo-0125"}

# One tokenizer (used only for chat templating) and one inference client per
# HF model. NOTE: both are created at import time and contact the Hub.
tokenizers = {k: AutoTokenizer.from_pretrained(m) for k, m in hf_models.items()}
clients = {k: InferenceClient(m, token=HF_TOKEN) for k, m in hf_models.items()}


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float, or *default* if unset."""
    raw = os.getenv(name)
    return default if raw is None else float(raw)


def _env_flag(name: str, default: bool) -> bool:
    """Return environment variable *name* parsed as a boolean, or *default* if unset.

    ``bool(os.getenv(...))`` treats every non-empty string -- including
    "False" and "0" -- as True, so the value is matched against explicit
    truthy spellings instead.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Sampling settings shared by both providers, parsed once.
# Temperature is floored at 0.01 so a configured 0 is never sent.
_TEMPERATURE = max(_env_float("TEMPERATURE", 0.9), 1e-2)
_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
_TOP_P = _env_float("TOP_P", 0.6)

HF_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_new_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    'repetition_penalty': _env_float("REP_PENALTY", 1.2),
    'do_sample': _env_flag("DO_SAMPLE", True),
}

OAI_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    # Clamp to [-2, 2], the range OpenAI documents for frequency_penalty.
    'frequency_penalty': max(-2, min(_env_float("FREQ_PENALTY", 0.0), 2)),
}


def format_prompt(message: str, model: str) -> Union[str, List[Dict[str, Any]]]:
    """
    Formats the given message for the given model.

    Args:
        message (str): The user message to be formatted.
        model (str): Model key; must be in ``openai_models`` or ``hf_models``.

    Returns:
        For OpenAI models, the chat ``messages`` list (a single user message).
        For Hugging Face models, the prompt string rendered by that model's
        tokenizer chat template, including the assistant-turn prefix when the
        template defines one.

    Raises:
        ValueError: If ``model`` is not supported.
    """
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
    if model in openai_models:
        return messages
    if model in hf_models:
        # add_generation_prompt appends the assistant header for templates that
        # use it (e.g. Llama 3) so the model answers instead of continuing the
        # user turn; templates that ignore the flag are unaffected.
        return tokenizers[model].apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    raise ValueError(f"Model {model} is not supported")


def generate_hf(
    model: str, prompt: str, history: str, _: str
) -> Generator[str, None, None]:
    """
    Stream a completion for ``prompt`` from the Hugging Face Inference client.

    Args:
        model (str): Key into ``hf_models``.
        prompt (str): The user message to respond to.
        history (str): Conversation history; currently unused.
        _ (str): Unused; keeps the signature parallel to ``generate_openai``.

    Yields:
        str: The accumulated generated text after each streamed token.

    Raises:
        ValueError: If ``model`` is not supported (from ``format_prompt``).
        gr.Error: If the inference request or stream fails.
    """
    formatted_prompt = format_prompt(prompt, model)
    try:
        stream = clients[model].text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False,
        )
        output = ""
        for response in stream:
            # Skip special tokens (e.g. the end-of-sequence marker) so they do
            # not leak into the displayed text.
            if response.token.special:
                continue
            output += response.token.text
            yield output
    except Exception as e:
        if "Too Many Requests" in str(e):
            raise gr.Error(f"Too many requests: {str(e)}") from e
        elif "Authorization header is invalid" in str(e):
            raise gr.Error("Authentication error: HF token was either not provided or incorrect") from e
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}") from e


def generate_openai(
    model: str, prompt: str, history: str, api_key: str
) -> Generator[str, None, None]:
    """
    Stream a chat completion for ``prompt`` from the OpenAI API.

    Args:
        model (str): An entry of ``openai_models``.
        prompt (str): The user message to respond to.
        history (str): Conversation history; currently unused.
        api_key (str): OpenAI API key used to build the client.

    Yields:
        str: The accumulated generated text after each non-empty content delta.

    Raises:
        ValueError: If ``model`` is not supported (from ``format_prompt``).
        gr.Error: If client construction, the request, or the stream fails.
    """
    formatted_prompt = format_prompt(prompt, model)
    try:
        # Built inside the try so a client-construction failure (e.g. no key
        # available) is reported through gr.Error like other failures.
        client = openai.Client(api_key=api_key)
        stream = client.chat.completions.create(
            model=model,
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True,
        )
        output = ""
        for chunk in stream:
            # Guard against chunks whose ``choices`` list is empty.
            if chunk.choices and chunk.choices[0].delta.content:
                output += chunk.choices[0].delta.content
                yield output
    except openai.RateLimitError as e:
        raise gr.Error("ERROR: Too many requests on OpenAI client") from e
    except openai.AuthenticationError as e:
        raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect") from e
    except Exception as e:
        # Message-based fallback for errors not raised as the typed exceptions above.
        if "Too Many Requests" in str(e):
            raise gr.Error("ERROR: Too many requests on OpenAI client") from e
        elif "You didn't provide an API key" in str(e):
            raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect") from e
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}") from e