import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD

# add logging info to console
logging.basicConfig(level=logging.INFO)

URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"

# The URIAL in-context prompt is fetched once at import time. The `with` block
# closes the HTTP response, and the timeout keeps start-up from hanging
# indefinitely when the URL is unreachable.
with urllib.request.urlopen(URIAL_URL, timeout=30) as _urial_resp:
    urial_prompt = _urial_resp.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```

# Generation stops at the first occurrence of any of these strings. They are
# sent to the API as `stop` and are also enforced client-side while streaming.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# A user-supplied Together key is only used when it has exactly this length;
# otherwise DEFAULT_API_KEY is used.
_TOGETHER_KEY_LEN = 64


def _find_stop(text):
    """Return the index of the earliest STOP_STRS occurrence in *text*, or None."""
    hits = [idx for idx in (text.find(stop) for stop in STOP_STRS) if idx != -1]
    return min(hits) if hits else None


def _strip_partial_stop(text):
    """Hide a trailing newline+quote fragment that may begin the triple-quote stop.

    Only applied to text shown to the user mid-stream. The accumulated response
    keeps these characters, so a stop string split across chunks is still
    detected and genuine quotes in the answer are not lost.
    """
    if text.endswith('\n""'):
        return text[:-2]
    if text.endswith('\n"'):
        return text[:-1]
    return text


def _chunk_text(chunk):
    """Return the text carried by one streamed chunk ("" when it carries none)."""
    delta = chunk.choices[0].delta
    try:
        content = delta["content"]
    except KeyError:
        # NOTE(review): presumably role-only / final chunks omit "content" — confirm
        return ""
    # The content field may be None on some chunks; treat that as empty text.
    return content or ""


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):
    """Stream a URIAL-prompted completion for *message* from a base LLM.

    Args:
        message: The latest user message.
        history: Previous (user, assistant) turns; formatted by urial_template.
        max_tokens: Forwarded unchanged to openai_base_request.
        temperature: Forwarded unchanged to openai_base_request.
        top_p: Forwarded unchanged to openai_base_request.
        rp: Repetition penalty from the UI. Currently ignored: it is always
            replaced with 1.0 before the request is made.
        model_name: UI display name, mapped through model_name_mapping.
        together_api_key: User-supplied key. It is used only when it is exactly
            64 characters long; otherwise DEFAULT_API_KEY is used.

    Yields:
        The response text generated so far, growing with each streamed chunk
        and cut off before the first stop string.
    """
    # NOTE(review): the UI "Repetition penalty" value is discarded here. This is
    # kept to preserve existing generation behaviour — confirm it is intended.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)
    if together_api_key and len(together_api_key) == _TOGETHER_KEY_LEN:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    request = openai_base_request(prompt=prompt, model=_model_name, temperature=temperature,
                                  max_tokens=max_tokens, top_p=top_p, repetition_penalty=rp,
                                  stop=STOP_STRS, api_key=api_key)

    response = ""
    for chunk in request:
        candidate = response + _chunk_text(chunk)
        cut = _find_stop(candidate)
        if cut is not None:
            # Keep the text produced before the stop string rather than
            # discarding the whole chunk that contained it.
            yield candidate[:cut]
            return
        response = candidate
        yield _strip_partial_stop(response)


with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", "Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"], value="Llama-3-8B", label="Base LLM name")
        with gr.Column():
            together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password", elem_id="api_key")
        with gr.Column():
            with gr.Row():
                max_tokens = gr.Textbox(value=256, label="Max tokens")
                temperature = gr.Textbox(value=0.5, label="Temperature")
                top_p = gr.Textbox(value=0.9, label="Top-p")
                rp = gr.Textbox(value=1.1, label="Repetition penalty")

    # The additional inputs are passed to `respond` after (message, history),
    # in this order.
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
    )
    chat.chatbot.height = 550

if __name__ == "__main__":
    demo.launch()