BaseChat / app.py
yuchenlin's picture
base chat
8df0f23
raw
history blame
3.38 kB
import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD
# add logging info to console
logging.basicConfig(level=logging.INFO)

# URIAL prompt text, downloaded from the Re-Align/URIAL GitHub repository at import time.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Seconds to wait for the prompt download; startup fails instead of hanging on a stalled network.
URIAL_FETCH_TIMEOUT = 30
# The context manager closes the HTTP response deterministically rather than leaving it to GC.
with urllib.request.urlopen(URIAL_URL, timeout=URIAL_FETCH_TIMEOUT) as _urial_resp:
    urial_prompt = _urial_resp.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```
# Strings that terminate a generated answer: sent to the API as stop sequences and also
# checked client-side while streaming in respond().
STOP_STRS = ['"""', '# Query:', '# Answer:']
def _first_stop_index(text, stop_strs):
    """Return the index of the earliest occurrence of any string in *stop_strs* in *text*, or -1."""
    positions = [pos for pos in (text.find(stop) for stop in stop_strs) if pos != -1]
    return min(positions) if positions else -1


def _hide_partial_quote(text):
    """Drop a trailing newline followed by one or two double quotes from *text*.

    Such a tail may be the beginning of a triple-quote stop marker that has not fully
    arrived yet, so it is hidden from the displayed text. Callers apply this to a copy for
    display only, so a genuine quote at the start of a line reappears once more text streams in.
    """
    if text.endswith('\n"'):
        return text[:-1]
    if text.endswith('\n""'):
        return text[:-2]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):
    """Stream a base-model answer for the Gradio chat UI using the URIAL prompt template.

    Builds a prompt from the module-level ``urial_prompt``, the chat history and the new
    message, requests a streamed completion via ``openai_base_request`` and yields the
    accumulated answer text after each streamed chunk.

    Args:
        message: The new user message.
        history: Earlier turns as pairs of strings (presumably (user, assistant) — confirm
            against ``urial_template``).
        max_tokens: Forwarded unchanged to ``openai_base_request``.
        temperature: Forwarded unchanged to ``openai_base_request``.
        top_p: Forwarded unchanged to ``openai_base_request``.
        rp: Repetition penalty from the UI; currently overridden with 1.0 (see NOTE below).
        model_name: UI model label, translated by ``model_name_mapping``.
        together_api_key: Optional user key; used only when it is 64 characters long,
            otherwise ``DEFAULT_API_KEY`` is used.

    Yields:
        The answer text received so far. Once any string in ``STOP_STRS`` appears, the text
        is cut just before it, yielded a final time, and streaming stops.
    """
    # NOTE(review): this overrides the repetition penalty passed in from the UI, so that
    # control currently has no effect — confirm this is intentional.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)
    # Only a key with the expected length is trusted; blank or malformed input falls back.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY
    request = openai_base_request(prompt=prompt, model=_model_name,
                                  temperature=temperature,
                                  max_tokens=max_tokens,
                                  top_p=top_p,
                                  repetition_penalty=rp,
                                  stop=STOP_STRS, api_key=api_key)
    # The raw stream text is never edited in place: stop detection always sees exactly what
    # the model produced, even when a stop marker is split across chunks.
    raw = ""
    for msg in request:
        # `delta` is read as a mapping (assumed dict-like — confirm against the client);
        # a chunk with missing or null content contributes nothing.
        raw += msg.choices[0].delta.get("content") or ""
        stop_at = _first_stop_index(raw, STOP_STRS)
        if stop_at != -1:
            # Keep any answer text that arrived in the same chunk before the stop marker.
            yield raw[:stop_at]
            break
        yield _hide_partial_quote(raw)
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        # Left column: header text and the base-model picker.
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                ["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
                 "Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"],
                value="Llama-3-8B",
                label="Base LLM name",
            )
        # Right column: optional API key, then the sampling parameters.
        with gr.Column():
            together_api_key = gr.Textbox(
                label="🔑 Together APIKey",
                placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.",
                type="password",
                elem_id="api_key",
            )
            with gr.Column():
                with gr.Row():
                    # NOTE(review): gr.Textbox values presumably reach respond() as str —
                    # confirm openai_base_request converts them to numbers.
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
    # additional_inputs are passed positionally after (message, history), so this order
    # must match respond()'s parameter list.
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
        # additional_inputs_accordion="⚙️ Parameters",
        # fill_height=True,
    )
    chat.chatbot.height = 550

if __name__ == "__main__":
    demo.launch()