Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
import requests | |
import config as cfg | |
# Placeholder choices shown in the "Skill" dropdown before the real list is fetched.
DEFAULT_SKILL = ["A", "B"]

# Deployment environments selectable in the UI; the first entry is the default.
ENVS = ["canary", "production"]

# Model identifiers accepted as the first "default_services" entry of a skill config.
MODELS = [
    "chat-gpt",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-0125",
    "gpt-4-preview",
    "gpt-4",
    "gpt-4-turbo",
    "general-reply-model",
    "gpt-4o",
    "gemini-1.5-flash-001",
    "gemini-1.5-pro-001",
    "gemini-1.5-flash-002",
    "gemini-1.5-pro-002",
    "gpt-4o-mini",
    "gpt-4o-2024-08-06",
    "gpt-4o-2024-05-13",
    "claude-3-5-sonnet@20240620",
]

# Model assumed when a config does not list any "default_services".
DEFAULT_MODEL = "chat-gpt"

# Usage instructions rendered at the top of the page.
MARKDOWN = """## Skill Repository
How to Use:
1. Click 'Get All skills' to get list of skill.
2. Choose specific skill that need to be updated
3. Click Get Skill Data to get skill information related prompt and config
4. Edit prompt and config, then click update skill to save and update the skill
"""
def edit_persona(skill_name, skprompt, config_json, env, update_to_prod: bool = False):
    """Validate the edited config and push the prompt/config to the skill service.

    The payload is first PUT to ``env``. When ``update_to_prod`` is true and
    ``env`` is ``"canary"``, the same payload is then PUT to ``"production"``.
    Every response is checked, so a failed canary update is reported instead of
    being masked by a later production response, and production is not touched
    once an earlier update has failed.

    Args:
        skill_name: Name of the skill to update (used as the URL path segment).
        skprompt: Prompt text to store for the skill.
        config_json: Skill configuration as a JSON string.
        env: Target environment, sent as the ``env`` query parameter.
        update_to_prod: Also push the update to production when ``env`` is
            ``"canary"``.

    Returns:
        The string ``"Success"`` when every attempted update returned HTTP 200.

    Raises:
        gr.Error: If the config fails validation, a request cannot be
            completed, or the service answers with a non-200 status.
    """
    config_dict = validate_config_json(config_json)
    body = json.dumps({"prompt": skprompt, "config": config_dict})

    target_envs = [env]
    if update_to_prod and env == "canary":
        # update prompt & config into production env as well
        target_envs.append("production")

    for target_env in target_envs:
        url = cfg.URL_SKILL + "/{}?env={}".format(skill_name, target_env)
        try:
            # A timeout keeps a stalled backend from blocking a queue worker forever.
            response = requests.put(url, data=body, timeout=30)
        except requests.RequestException as exc:
            raise gr.Error(
                f"Failed to update skill in {target_env}, please try again"
            ) from exc
        if response.status_code != 200:
            raise gr.Error(f"Failed to update skill in {target_env}, please try again")

    return "Success"
def get_all_skills():
    """Fetch the available skill names and return them as dropdown choices.

    Returns:
        A ``gr.Dropdown`` update whose choices are the value found at
        ``["data"]["skills"]`` in the JSON response, or an empty list when the
        service does not answer with HTTP 200.

    Raises:
        gr.Error: If the request cannot be completed (connection failure or
            timeout).
    """
    try:
        # A timeout keeps a stalled backend from blocking a queue worker forever.
        response = requests.get(cfg.URL_SKILLS, timeout=30)
    except requests.RequestException as exc:
        raise gr.Error("Failed to fetch skills, please try again") from exc

    if response.status_code == 200:
        choices = response.json()["data"]["skills"]
    else:
        choices = []
    return gr.Dropdown(choices=choices)
def get_skill_data(skill_name, env):
    """Fetch the stored prompt and config of one skill in one environment.

    Args:
        skill_name: Name of the skill to look up (used as the URL path segment).
        env: Environment to read from, sent as the ``env`` query parameter.

    Returns:
        A ``(prompt, config_json)`` tuple. ``config_json`` is the skill config
        pretty-printed with a 4-space indent. Both values are empty strings
        when no skill is selected or the service does not answer with HTTP 200.

    Raises:
        gr.Error: If the request cannot be completed (connection failure or
            timeout).
    """
    if not skill_name:
        # Nothing selected: avoid requesting a meaningless ".../None" URL.
        return "", ""

    url = cfg.URL_SKILL + "/{}?env={}".format(skill_name, env)
    try:
        # A timeout keeps a stalled backend from blocking a queue worker forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        raise gr.Error("Failed to fetch skill data, please try again") from exc

    if response.status_code != 200:
        return "", ""

    data = response.json()["data"]
    return data["prompt"], json.dumps(data["config"], indent=4)
def _is_number(value):
    """Return True for int/float values, excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def validate_config_json(config_json):
    """Parse a skill config JSON string and check the fields the UI cares about.

    Checks performed:
      * the text is valid JSON and its top level is an object;
      * the first entry of ``default_services`` (default: ``[DEFAULT_MODEL]``)
        is one of ``MODELS``;
      * ``completion.temperature``, when present, is a number in ``[0, 1]``;
      * ``completion.max_tokens`` (default: 100) is a number greater than 0.

    Args:
        config_json: The configuration as a JSON string.

    Returns:
        The parsed configuration dictionary.

    Raises:
        gr.Error: If the text is not valid JSON or any check above fails.
    """
    try:
        config_dict = json.loads(config_json)
    except (TypeError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError; TypeError covers a None input.
        raise gr.Error(f"Invalid config JSON: {exc}") from exc
    if not isinstance(config_dict, dict):
        raise gr.Error("config must be a JSON object")

    models = config_dict.get("default_services", [DEFAULT_MODEL])
    if not isinstance(models, list):
        raise gr.Error("default_services must be a list")
    if models and models[0] not in MODELS:
        raise gr.Error("the model not in the listed")

    completion = config_dict.get("completion", {})
    if not isinstance(completion, dict):
        raise gr.Error("completion must be a JSON object")

    # Temperature is optional; only validate it when the config provides one.
    temp = completion.get("temperature")
    if temp is not None and not (_is_number(temp) and 0 <= temp <= 1):
        raise gr.Error("temperature must be in [0, 1]")

    max_tokens = completion.get("max_tokens", 100)
    if not _is_number(max_tokens) or max_tokens <= 0:
        raise gr.Error("max_token must be greater than 0")

    return config_dict
# Build the Gradio UI for browsing and editing skills.
# NOTE(review): indentation was lost in the copy of this file that was reviewed;
# the nesting of rows/columns below is reconstructed and should be confirmed.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Step 1: fetch the list of skills into the "Skill" dropdown.
    skills = gr.Button("1. Get All skills")
    with gr.Row():
        with gr.Column(scale=5):
            env = gr.Radio(ENVS, value=ENVS[0], label="Environment")
        with gr.Column(scale=5):
            skill_name = gr.Dropdown(DEFAULT_SKILL, label="Skill")

    # Step 2: load the selected skill's prompt and config into the editors.
    skill_data = gr.Button("2. Get Skill Data")
    prompt = gr.Textbox(placeholder="You are helpful assistant", lines=10, label="skprompt")
    config_json = gr.Textbox(placeholder="Config Json", lines=5, label="config json")

    # Step 3: save the edited prompt/config; the result is shown in the status box.
    with gr.Row():
        with gr.Column(scale=5):
            update_to_prod = gr.Radio([True, False], value=False, label="Update to production")
            update_skill = gr.Button("3. Save and Update Skill")
        with gr.Column(scale=5):
            status_box = gr.Textbox(label="Status code from server", )

    # Wire each button to its handler; all handlers share the concurrency limit from cfg.
    skills.click(get_all_skills, [], [skill_name], concurrency_limit=cfg.QUEUE_CONCURENCY_COUNT)
    skill_data.click(get_skill_data, [skill_name, env], [prompt, config_json], concurrency_limit=cfg.QUEUE_CONCURENCY_COUNT)
    update_skill.click(edit_persona, [skill_name, prompt, config_json, env, update_to_prod], [status_box], concurrency_limit=cfg.QUEUE_CONCURENCY_COUNT)

# Enable request queuing (size taken from cfg) and launch the app with the
# username/password pair from cfg passed as `auth`.
(
    demo
    .queue(max_size=cfg.QUEUE_MAX_SIZE)
    .launch(auth=(cfg.USERNAME, cfg.PASSWORD), debug=True)
)