import asyncio
import json

import Agently
import gradio as gr

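# Preset model providers: each entry maps a display name in the UI to the
# Agently model request plugin to use ("model_serial") plus its request options.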
model_dict = {
    "GPT 3.5": {
        "model_serial": "OAIClient",
        "options": { "model": "gpt-3.5-turbo" },
    },
    "GPT 4": {
        "model_serial": "OpenAI",
        "options": { "model": "gpt-4" },
    },
    "Claude 3": {
        "model_serial": "Claude",
        "options": { "model": "claude-3-opus-20240229" },
    },
    "Gemini Pro": {
        "model_serial": "Google",
        "options": {},
    },
    "Baidu Ernie 4.0": {
        "model_serial": "ERNIE",
        "options": { "model": "ernie-4.0" },
    },
    "Zhipu GLM 4": {
        "model_serial": "ZhipuAI",
        "options": { "model": "glm-4" },
    },
    "Kimi": {
        "model_serial": "OAIClient",
        "url": "https://api.moonshot.cn/v1",
        "options": { "model": "moonshot-v1-8k" },
    },
}

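# Gradio chat handler: configures an Agently agent from the user's settings,
# then streams the agent's reply back to the ChatInterface.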
async def chat(message, history, model, base_url, auth, agent_id, session_id):
    if not auth or not model:
        yield "Welcome! If this is your first time here, please add your model settings in \"Additional Inputs\" below."
        return
    if not session_id:
        session_id = "$AGENTLY_GRADIO_SESSION"
    model_info = model_dict[model].copy()
    model_serial = model_info["model_serial"]
    del model_info["model_serial"]
    agent = (
        Agently.create_agent(agent_id)
        .set_settings("current_model", model_serial)
        .set_settings("is_debug", False)
    )
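    # Apply provider-specific settings, then attach the credential
    # (ERNIE expects an AI Studio access token, the rest an API key).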
    for setting_name, setting_value in model_info.items():
        agent.set_settings(f"model.{ model_serial }.{setting_name}", setting_value)
    if model_serial == "ERNIE":
        agent.set_settings(f"model.{ model_serial }.auth", { "aistudio": auth })
    else:
        agent.set_settings(f"model.{ model_serial }.auth", { "api_key": auth })
    if base_url:
        agent.set_settings(f"model.{ model_serial }.url", base_url)
    reply_queue = asyncio.Queue()

    async def wait_to_yield():
        buffer = ""
        yield buffer
        while True:
            reply = await reply_queue.get()
            buffer += reply
            if reply == "$STOP":
                break
            yield buffer

    yield_gen = wait_to_yield()
    asyncio.ensure_future(yield_gen.__anext__())
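    # Background task: run the agent request, push streamed deltas into the
    # queue, and send the "$STOP" sentinel when the reply is finished.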
    async def start_agent():
        nonlocal session_id, message, history
        agent.active_session(session_id)
        if message == "#erase":
            agent.rewrite_chat_history([])
            agent.stop_session()
            await reply_queue.put("Chat history has been reset.")
            await reply_queue.put("$STOP")
        else:
            if len(history) == 0 and session_id == "$AGENTLY_GRADIO_SESSION":
                agent.rewrite_chat_history([])
                agent.save_session()

            @agent.on_event("delta")
            async def delta_handler(data):
                await reply_queue.put(data)

            try:
                (
                    await agent
                    .use_public_tools(["search", "browse"])
                    .input(message)
                    .start_async()
                )
            except Exception as e:
                await reply_queue.put(f"Error: { str(e) }")
                await reply_queue.put("$STOP")
            agent.stop_session()
            await reply_queue.put("$STOP")
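    # Kick off the agent task, then forward buffered output to Gradio until done.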
    start_agent_task = asyncio.create_task(start_agent())
    while True:
        try:
            wait_to_yield_task = asyncio.ensure_future(yield_gen.__anext__())
            value = await wait_to_yield_task
            if value == "$STOP":
                return
            else:
                yield value
        except StopAsyncIteration:
            return

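# Build the chat UI; model settings are collected through "Additional Inputs".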
iface = gr.ChatInterface(
    chat,
    title="Agently Gradio Chat Interface",
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
    additional_inputs=[
        gr.Radio(list(model_dict.keys()), value="GPT 3.5", label="Choose Your Model"),
        gr.Textbox(placeholder="Input API Base URL (leave empty to use the default URL)", label="Base URL"),
        gr.Textbox(placeholder="Input API-KEY or Access-Token", label="API-KEY"),
        gr.Textbox(value="demo_agent", label="Agent ID"),
        gr.Textbox(placeholder="Input any identifier string to save the chat history", label="Session ID"),
    ],
)

iface.launch()
|