Spaces:
Running
Running
File size: 3,378 Bytes
0231e6a 54f7da0 8df0f23 0231e6a 8df0f23 0231e6a 54f7da0 0231e6a 460e2a9 0231e6a 54f7da0 0231e6a 8df0f23 0231e6a 8df0f23 460e2a9 0231e6a 460e2a9 0231e6a 54f7da0 460e2a9 0231e6a 54f7da0 9e82682 0231e6a 9e82682 0231e6a 9e82682 0231e6a 8df0f23 a19b85a 0231e6a 9e82682 0231e6a 8359b58 0231e6a f358cdd 460e2a9 9e82682 460e2a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request, DEFAULT_API_KEY
from constant import js_code_label, HEADER_MD
# Send log records of level INFO and above to the console.
logging.basicConfig(level=logging.INFO)

# Name (without the ".txt" suffix) of the prompt file fetched from the
# Re-Align/URIAL repository.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"

# Download the prompt once, at import time. The context manager closes the
# HTTP response as soon as the body has been read instead of leaving the
# connection open until the object is garbage-collected.
with urllib.request.urlopen(URIAL_URL) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```

# Substrings that end a streamed answer: they are sent to the API as stop
# sequences and are also checked client-side while tokens arrive.
STOP_STRS = ['"""', '# Query:', '# Answer:']
def _visible_text(text):
    """Return ``text`` without a dangling, possibly partial, stop marker.

    A newline followed by one or two double quotes at the very end of the
    text may be the first characters of the triple-quote stop string that is
    still being streamed, so those quotes are hidden from the user for now.
    The caller keeps the untrimmed text, so the quotes reappear if the next
    token shows they were ordinary content.
    """
    if text.endswith('\n"'):
        return text[:-1]
    if text.endswith('\n""'):
        return text[:-2]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):
    """Stream a URIAL-prompted completion for ``message``.

    Args:
        message: The new user message.
        history: Earlier (user, assistant) turns; passed to ``urial_template``
            together with the URIAL prompt and ``message``.
        max_tokens: Forwarded unchanged to ``openai_base_request``.
        temperature: Forwarded unchanged to ``openai_base_request``.
        top_p: Forwarded unchanged to ``openai_base_request``.
        rp: Repetition penalty from the UI. Currently ignored: the request
            is always sent with a penalty of 1.0.
        model_name: UI label of the model; translated with
            ``model_name_mapping`` before the request is made.
        together_api_key: Used only when it is exactly 64 characters long;
            otherwise ``DEFAULT_API_KEY`` is used.

    Yields:
        The response accumulated so far, once per non-empty streamed token,
        with a trailing partial stop marker hidden. Streaming ends as soon as
        the accumulated text would contain one of ``STOP_STRS``.
    """
    # NOTE(review): the UI value is overridden here, so the "Repetition
    # penalty" field has no effect -- confirm this is intentional.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    # Anything that is not a 64-character string falls back to the default key.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    request = openai_base_request(prompt=prompt, model=_model_name,
                                  temperature=temperature,
                                  max_tokens=max_tokens,
                                  top_p=top_p,
                                  repetition_penalty=rp,
                                  stop=STOP_STRS, api_key=api_key)

    response = ""
    for msg in request:
        token = msg.choices[0].delta["content"]
        if not token:
            # A chunk without text (None or "") adds nothing to display;
            # concatenating None would raise TypeError.
            continue
        candidate = response + token
        # Stop before emitting any text that contains a stop marker.
        if any(_stop in candidate for _stop in STOP_STRS):
            break
        # Keep the raw text for stop detection and trim only what is shown.
        # Trimming the stored text would permanently drop a legitimate quote
        # and would hide a stop marker that is split across several tokens.
        response = candidate
        yield _visible_text(response)
# Build the Gradio page: header text with the base-model selector, the
# Together API-key field, the sampling-parameter fields, and the chat
# interface that streams from respond().
# NOTE(review): this copy of the file has lost its indentation, so the nesting
# of the `with` containers below cannot be read off the text; restore it from
# the original before running.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
with gr.Row():
with gr.Column():
# Header markdown followed by the model selector (default: "Llama-3-8B").
gr.Markdown(HEADER_MD)
model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1",
"Mixtral-8x22B", "Yi-6B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
, value="Llama-3-8B", label="Base LLM name")
with gr.Column():
# Optional user key, entered as a password field; respond() falls back to
# DEFAULT_API_KEY unless the value is exactly 64 characters long.
together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password", elem_id="api_key")
with gr.Column():
with gr.Row():
# Sampling parameters are plain Textbox fields.
# NOTE(review): presumably their values reach respond() as text rather than
# numbers -- confirm that openai_base_request accepts that.
max_tokens = gr.Textbox(value=256, label="Max tokens")
temperature = gr.Textbox(value=0.5, label="Temperature")
top_p = gr.Textbox(value=0.9, label="Top-p")
# NOTE(review): respond() overwrites rp with 1.0, so this field currently has
# no effect on the request.
rp = gr.Textbox(value=1.1, label="Repetition penalty")
# additional_inputs are listed in the same order as respond()'s parameters
# after (message, history).
chat = gr.ChatInterface(
respond,
additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
# additional_inputs_accordion="⚙️ Parameters",
# fill_height=True,
)
# Set the chatbot component's height attribute (presumably pixels).
chat.chatbot.height = 550
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()