"""Side-by-side Gradio demo: a base LLM prompted with URIAL vs. its aligned pair.

One "Chat" click sends the same prompt to both panes. The left pane wraps the
conversation in the URIAL in-context prompt for a base model; the right pane
uses a chat template for the paired aligned model.
"""
import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, chat_template, openai_chat_request
from constant import js_code_label, my_css, HEADER_MD, BASE_TO_ALIGNED, MODELS
from openai import OpenAI
import datetime

# add logging info to console
logging.basicConfig(level=logging.INFO)

URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Seconds to wait for the URIAL prompt download before failing at start-up.
URIAL_FETCH_TIMEOUT = 30

# Base-model (URIAL) output is cut at the first occurrence of any of these.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Requests served on the shared API key, counted per client address and
# reset once LAST_UPDATE_TIME is more than a day old.
DAILY_REQUEST_LIMIT = 100
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()

models = MODELS


def _fetch_urial_prompt(url, timeout=URIAL_FETCH_TIMEOUT):
    """Download the URIAL in-context prompt at ``url`` and normalize its fences.

    The response is closed by the context manager, and ``timeout`` bounds the
    wait so a stalled connection fails instead of hanging start-up forever.
    """
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        text = resp.read().decode("utf-8")
    # new version of URIAL uses """ instead of ```
    return text.replace("```", '"""')


urial_prompt = _fetch_urial_prompt(URIAL_URL)


def _extract_token(chunk):
    """Return the text carried by one streamed response chunk, or None.

    Chunks exposing ``choices[0].delta`` carry text in ``delta.content`` (or
    ``delta["content"]`` when the delta is a dict); other chunks carry it in
    ``choices[0].text``. A chunk with an empty ``choices`` list yields None.
    """
    choices = chunk.choices
    if not choices:
        return None
    choice = choices[0]
    if hasattr(choice, "delta"):
        delta = choice.delta
        # Note: the delta may be a dict or a non-subscriptable 'ChoiceDelta' object.
        if isinstance(delta, dict):
            return delta.get("content")
        return getattr(delta, "content", None)
    return getattr(choice, "text", None)


def _find_stop(text, stop_strs):
    """Return the index of the earliest stop string in ``text``, or -1 if none occurs."""
    positions = [pos for pos in (text.find(s) for s in stop_strs) if pos != -1]
    return min(positions, default=-1)


def _strip_partial_stop(text):
    """Hide one or two trailing double quotes that directly follow a newline.

    Such a tail may be the beginning of the triple-quote stop string, so it is
    not shown while streaming.
    """
    if text.endswith('\n"'):
        return text[:-1]
    if text.endswith('\n""'):
        return text[:-2]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    model_type,
    api_key,
    request: gr.Request,
):
    """Stream a reply from a base model (via URIAL) or an aligned model (via chat template).

    Args:
        message: The user's new prompt.
        history: Prior turns of this chatbot pane; None is treated as empty.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        rp: Repetition penalty.
        model_name: Display name, mapped to the provider id by ``model_name_mapping``.
        model_type: ``"base"`` or ``"aligned"``.
        api_key: User-supplied key; only a 64-character value is used, anything
            else is replaced by None (presumably the app's shared key inside
            ``utils`` — confirm). Only shared-key requests count toward the
            per-address daily limit.
        request: Gradio request, used for the client host address.

    Yields:
        ``history`` plus the ``(message, partial_response)`` pair, after every
        streamed token. When the daily limit is hit, a single pair carrying the
        limit notice is yielded instead.

    Raises:
        ValueError: If ``model_type`` is not ``"base"`` or ``"aligned"``.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    if model_type not in ("base", "aligned"):
        raise ValueError(f"model_type must be 'base' or 'aligned', got {model_type!r}")
    history = list(history or [])

    if not (api_key and len(api_key) == 64):
        api_key = None
    uses_shared_key = api_key is None

    # if already 24 hours passed, reset the counter
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    # Note: one Chat click invokes respond twice (base + aligned); each call counts.
    if uses_shared_key and addr_limit_counter.get(host_addr, 0) >= DAILY_REQUEST_LIMIT:
        # respond is a generator: a `return value` would be silently dropped,
        # so the notice has to be yielded as a chat turn.
        yield history + [(
            message,
            f"You have reached the limit of {DAILY_REQUEST_LIMIT} requests for today. "
            "Please use your own API key.",
        )]
        return

    provider_model = model_name_mapping(model_name)
    if model_type == "base":
        prompt = urial_template(urial_prompt, history, message)
        infer_request = openai_base_request(
            prompt=prompt, model=provider_model, temperature=temperature, max_tokens=max_tokens,
            top_p=top_p, repetition_penalty=rp, stop=STOP_STRS, api_key=api_key,
        )
    else:
        messages = chat_template(history, message)
        infer_request = openai_chat_request(
            messages=messages, model=provider_model, temperature=temperature, max_tokens=max_tokens,
            top_p=top_p, repetition_penalty=rp, stop=STOP_STRS, api_key=api_key,
        )

    if uses_shared_key:
        addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info("Requesting chat completion from OpenAI API with model %s", provider_model)
    logging.info("addr_limit_counter: %s; Last update time: %s;", addr_limit_counter, LAST_UPDATE_TIME)

    is_base = model_type == "base"
    raw_response = ""
    for chunk in infer_request:
        token = _extract_token(chunk)
        if token is None:
            continue
        candidate = raw_response + token
        if is_base:
            stop_at = _find_stop(candidate, STOP_STRS)
            if stop_at != -1:
                # Keep whatever was generated before the stop string, then end.
                yield history + [(message, _strip_partial_stop(candidate[:stop_at]))]
                return
        raw_response = candidate
        # Trim only the displayed text: raw_response keeps trailing quotes so a
        # triple-quote stop split across tokens is still detected next time.
        shown = _strip_partial_stop(raw_response) if is_base else raw_response
        yield history + [(message, shown)]


def load_models(base_model_name):
    """Relabel both chat panes for the chosen base model and show its aligned pair.

    Args:
        base_model_name: Selected base model; used as a key into BASE_TO_ALIGNED.

    Returns:
        Three ``gr.update`` objects for (base chatbot, aligned chatbot,
        aligned-model textbox); the textbox is set to the aligned model name
        and made non-interactive.
    """
    logging.info("base_model_name=%s", base_model_name)
    aligned_model_name = BASE_TO_ALIGNED[base_model_name]
    return (
        gr.update(label=f"Chat with Base LLM: {base_model_name}"),
        gr.update(label=f"Chat with Aligned LLM: {aligned_model_name}"),
        gr.update(value=aligned_model_name, interactive=False),
    )


def clear_fn():
    """Reset the prompt box and both chat panes (wired to [message, chat_a, chat_b])."""
    return None, None, None


with gr.Blocks(gr.themes.Soft(), js=js_code_label, css=my_css) as demo:
    api_key = gr.Textbox(
        label="🔑 APIKey",
        placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
        type="password",
        elem_id="api_key",
        visible=False,
    )
    gr.Markdown(HEADER_MD)
    with gr.Row():
        chat_a = gr.Chatbot(height=500, label="Chat with Base LLMs via URIAL")
        chat_b = gr.Chatbot(height=500, label="Chat with Aligned LLMs")
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1.5):
                message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            left_model_choice = gr.Dropdown(label="Base Model", choices=models, interactive=True)
                            right_model_choice = gr.Textbox(label="Aligned Model", placeholder="xxx", visible=True)
                        with gr.Row():
                            btn = gr.Button("🚀 Chat")
                        with gr.Row():
                            stop_btn = gr.Button("⏸️ Stop")
                            clear_btn = gr.Button("🔁 Clear")
                        with gr.Row():
                            gr.Markdown(">> - We thank for the support of Llama-3.1-405B from [Hyperbolic AI](https://hyperbolic.xyz/). ")
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Params for **Base** LLM", open=True):
                    with gr.Row():
                        max_tokens_1 = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
                        temperature_1 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    with gr.Row():
                        top_p_1 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                        rp_1 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.1)
                with gr.Accordion("⚙️ Params for **Aligned** LLM", open=True):
                    with gr.Row():
                        max_tokens_2 = gr.Slider(label="Max tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
                        temperature_2 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    with gr.Row():
                        top_p_2 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                        rp_2 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    left_model_choice.value = "Llama-3.1-405B-FP8"
    right_model_choice.value = "Llama-3.1-405B-Instruct-BF16"
    left_model_choice.change(load_models, [left_model_choice], [chat_a, chat_b, right_model_choice])

    # Hidden constants telling respond which pane (and prompting mode) it serves.
    model_type_left = gr.Textbox(visible=False, value="base")
    model_type_right = gr.Textbox(visible=False, value="aligned")

    # One click fires both panes; Stop cancels both streams.
    go1 = btn.click(respond, [message, chat_a, max_tokens_1, temperature_1, top_p_1, rp_1, left_model_choice, model_type_left, api_key], chat_a)
    go2 = btn.click(respond, [message, chat_b, max_tokens_2, temperature_2, top_p_2, rp_2, right_model_choice, model_type_right, api_key], chat_b)
    stop_btn.click(None, None, None, cancels=[go1, go2])
    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])


if __name__ == "__main__":
    demo.launch(show_api=False)