Spaces:
Sleeping
Sleeping
import os | |
import json | |
from time import sleep | |
import gradio as gr | |
import uvicorn | |
from datetime import datetime | |
from typing import List, Tuple | |
from starlette.config import Config | |
from starlette.middleware.sessions import SessionMiddleware | |
from starlette.responses import RedirectResponse | |
from authlib.integrations.starlette_client import OAuth, OAuthError | |
from fastapi import FastAPI, Request | |
from shared import Client, User, OAuthProvider | |
# FastAPI application; the Gradio UI is mounted onto it under `__main__`.
app = FastAPI()
# Parsed JSON configuration from the CONFIG envvar (filled by `init_config`).
config: dict = {}
# LLM host name -> Client (filled by `init_config`).
clients: dict = {}
# Configured LLM host names in configuration order (set by `init_config`).
llm_host_names: List[str] = []
# Authlib OAuth registry with the "google" client (set by `init_oauth`).
oauth = None
def init_oauth():
    """
    Register the Google OAuth client and install session middleware on `app`.

    Reads GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET for the OAuth client and
    SECRET_KEY for signing the session cookie. If SECRET_KEY is not set, a
    random key is generated for this process, so sessions are invalidated
    when the process restarts. Set SECRET_KEY to keep logins across restarts.
    """
    import secrets

    global oauth
    google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
    google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
    secret_key = os.environ.get('SECRET_KEY')
    if not secret_key:
        # Never fall back to a well-known constant: the session cookie holds
        # the user's identity (see `get_user`), and anyone who knows the
        # signing key can forge it and impersonate any allowed domain.
        print("SECRET_KEY is not set; using a random session key. "
              "Sessions will not persist across restarts.")
        secret_key = secrets.token_urlsafe(32)
    starlette_config = Config(environ={"GOOGLE_CLIENT_ID": google_client_id,
                                       "GOOGLE_CLIENT_SECRET": google_client_secret})
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'}
    )
    app.add_middleware(SessionMiddleware, secret_key=secret_key)
def init_config():
    """
    Load the JSON configuration from the `CONFIG` environment variable and
    create one `Client` per configured LLM host.

    Host entries are read from the top-level "clients" object. If that key is
    missing or empty, the whole configuration object is treated as the host
    mapping:
        {"clients": {"<llm_host_name>": {"api_url": "<api_url>",
                                         "api_key": "<api_key>",
                                         "personas": {"<name>": "<prompt>"}
                                         }
                     }
        }
    `api_url` and `api_key` may each name an environment variable; if no
    such variable exists, the configured string is used literally.
    `personas` is optional (default `{}`) and maps a persona name to the
    system prompt used by `respond` (a null prompt sends no system message).
    Other top-level keys (`permissions_override`, `huggingface_text`) are read
    elsewhere. They only work with the "clients" layout, because otherwise
    every top-level key is treated as an LLM host.

    Side effects: rebinds the module-level `config` and `llm_host_names`, and
    adds entries to the module-level `clients` dict.
    """
    global config, clients, llm_host_names

    def resolve(value):
        # Treat the configured value as an envvar name, else as a literal
        return os.environ.get(value, value)

    config = json.loads(os.environ['CONFIG'])
    host_configs = config.get("clients") or config
    for host_name in host_configs:
        host_config = host_configs[host_name]
        personas = host_config.get("personas", {})
        clients[host_name] = Client(
            api_url=resolve(host_config['api_url']),
            api_key=resolve(host_config['api_key']),
            personas=personas,
        )
    llm_host_names = list(host_configs.keys())
def get_allowed_models(user: User) -> List[str]:
    """
    Decide, for every configured model, whether `user` may use it.

    A model's permissions come from `permissions_override[<model>]` in the
    demo configuration when present, otherwise from the client's own
    `config.inference.permissions`. Empty or None permissions make the model
    public. Otherwise, a Google-authenticated user is allowed when their
    `permissions_id` is listed in the permission's "google_domains".
    :param user: User to get permissions for
    :return: One entry per configured client, in `clients` order: the model
        name when allowed, or an empty string when not allowed
    """
    overrides = config.get("permissions_override", {})
    result = []
    for name, client in clients.items():
        permission = overrides.get(name,
                                   client.config.inference.permissions)
        is_public = not permission
        is_domain_member = (not is_public
                            and user.oauth == OAuthProvider.GOOGLE
                            and user.permissions_id in
                            permission.get("google_domains", []))
        if is_public or is_domain_member:
            result.append(name)
        else:
            result.append("")
            print(f"No permission to access {name}")
    return result
def parse_radio_select(radio_select: tuple) -> Tuple[str, str]:
    """
    Determine the requested model and persona from the radio selections.

    `radio_select` holds one value per model radio, in the same order as
    `llm_host_names`. Unselected radios are None, and the first non-None
    entry wins.
    :param radio_select: Sequence of radio values, one per configured model
    :return: Selected model (LLM host name), selected persona
    :raises ValueError: if no radio has a selection
    """
    for index, persona in enumerate(radio_select):
        if persona is not None:
            return llm_host_names[index], persona
    # An explicit error instead of a bare StopIteration from `next()`: that
    # exception cannot propagate normally out of a handler run off-thread.
    raise ValueError("No model/persona selected")
def get_login_button(request: gr.Request) -> gr.Button:
    """
    Build the login/logout button for the session behind `request`.

    Guests get a "Login" button linking to `/login`. Signed-in users get a
    "Logout <username>" button linking to `/logout`.
    :param request: Gradio request to evaluate
    :return: Button for either login or logout action
    """
    username = get_user(request).username
    print(f"Getting login button for {username}")
    if username != "guest":
        return gr.Button(f"Logout {username}", link="/logout")
    return gr.Button("Login", link="/login")
def get_user(request: Request) -> User:
    """
    Build a `User` describing whoever owns the session attached to `request`.

    A session whose stored "user" claims were issued by Google becomes a
    GOOGLE user identified by email, with the hosted domain (`hd`) as its
    permissions id. Any other session, a missing session user, or a falsy
    `request` is treated as the "guest" user.
    :param request: FastAPI/Gradio request with user session data, or None
    :return: User for the session; username is the email address or "guest"
    """
    # Example of the Google ID token claims `auth` stores in the session
    # (values elided):
    # {'iss': 'https://accounts.google.com', 'azp': '***', 'aud': '***',
    #  'sub': '###', 'hd': 'example.com', 'email': 'user@example.com',
    #  'email_verified': True, 'at_hash': '***', 'nonce': '***',
    #  'name': '***', 'picture': '***', 'given_name': '***',
    #  'family_name': '***', 'iat': ###, 'exp': ###}
    if not request:
        return User(OAuthProvider.NONE, "guest", "")
    # Google documents that `iss` may be either form
    google_issuers = ("https://accounts.google.com", "accounts.google.com")
    user_dict = request.session.get("user", {})
    if user_dict.get("iss") in google_issuers:
        # `hd` is only present for Workspace/organization accounts. Other
        # Google accounts get an empty permissions id instead of a KeyError.
        user = User(OAuthProvider.GOOGLE, user_dict["email"],
                    user_dict.get("hd", ""))
    else:
        if user_dict:
            print(f"Unknown user session data: {user_dict}")
        user = User(OAuthProvider.NONE, "guest", "")
    print(user)
    return user
@app.route('/logout')
async def logout(request: Request):
    """
    Remove the user session context and reload an un-authenticated session.

    Served at `/logout`, the target of the logout button built by
    `get_login_button`. Without the route registration, that link would not
    reach this handler.
    :param request: FastAPI Request object with user session data
    :return: Redirect to `/`
    """
    request.session.pop('user', None)
    return RedirectResponse(url='/')
@app.route('/login')
async def login(request: Request):
    """
    Start the OAuth flow for login with Google.

    Served at `/login`, the target of the login button built by
    `get_login_button`.
    :param request: FastAPI Request object
    :return: Redirect to Google's authorization endpoint
    """
    from urllib.parse import urlparse, urlunparse
    redirect_uri = request.url_for('auth')
    # Force https on the callback URL. Presumably the app runs behind a
    # TLS-terminating proxy, so the URL seen here is http -- confirm for
    # other deployments.
    redirect_uri = urlunparse(urlparse(str(redirect_uri))._replace(scheme='https'))
    return await oauth.google.authorize_redirect(request, redirect_uri)
@app.route('/auth')
async def auth(request: Request):
    """
    Callback endpoint for Google OAuth.

    Registered as the route named `auth`, which `login` resolves with
    `request.url_for('auth')`. On success, the token's `userinfo` claims are
    stored in the session under "user". On failure, the session is left
    unchanged. Either way, the browser is redirected to `/`.
    :param request: FastAPI Request object
    :return: Redirect to `/`
    """
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError as e:
        # Still a soft failure (back to the app as guest), but no longer silent
        print(f"OAuth login failed: {e}")
        return RedirectResponse(url='/')
    userinfo = dict(access_token).get("userinfo")
    if userinfo:
        request.session['user'] = userinfo
    else:
        print("OAuth token did not include userinfo; login aborted")
    return RedirectResponse(url='/')
def respond(
        message: str,
        history: List[Tuple[str, str]],
        conversational: bool,
        max_tokens: int,
        *radio_select,
):
    """
    Generate a reply to `message` with the model and persona chosen in the UI.

    The prompt is built from three parts, in order:
    - the persona's system prompt, unless it is None;
    - when `conversational` is set, the last two (user, assistant) pairs from
      `history`, skipping empty sides;
    - `message` itself.
    The prompt is then sent as a chat completion to the selected client.
    :param message: String input from the user
    :param history: Chat history as (<user message>, <llm message>) pairs
    :param conversational: If true, include recent chat history
    :param max_tokens: Maximum tokens for the LLM to generate
    :param radio_select: Radio values, one per model, parsed by
        `parse_radio_select`
    :return: String LLM response
    :raises gr.Error: if the selected model does not define the persona
    """
    model, persona = parse_radio_select(radio_select)
    client = clients[model]
    try:
        system_prompt = client.personas[persona]
    except KeyError:
        supported_personas = list(client.personas.keys())
        raise gr.Error(f"Model '{model}' does not support persona '{persona}', only {supported_personas}")

    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    if conversational:
        for turn in history[-2:]:
            user_text, assistant_text = turn[0], turn[1]
            for role, content in (("user", user_text),
                                  ("assistant", assistant_text)):
                if content:
                    messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})

    vllm_options = {
        "add_special_tokens": True,
        "repetition_penalty": 1.05,
        "use_beam_search": True,
        "best_of": 5,
    }
    completion = client.openai.chat.completions.create(
        model=client.vllm_model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0,
        extra_body=vllm_options,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def get_model_options(request: gr.Request) -> List[gr.Radio]:
    """
    Build one persona radio per configured model for the given session.

    Models the user may not access get an empty "Not Authorized" radio, so
    the number and order of radios always matches the configured clients.
    The first persona of the first accessible model that defines any
    personas is preselected.
    :param request: Gradio request object to get user from (None for guest)
    :return: List of Radio objects, one per configured model
    """
    if request:
        # `user` is a valid Google email address or 'guest'
        user = get_user(request.request)
    else:
        user = User(OAuthProvider.NONE, "guest", "")
    print(f"Getting models for {user.username}")
    allowed_llm_host_names = get_allowed_models(user)

    # Pick the default from the first *authorized* model with personas.
    # Disallowed entries are "" and never in `clients`, so blindly using index
    # 0 would raise a KeyError whenever the first model is restricted.
    default_index, default_persona = None, None
    for index, name in enumerate(allowed_llm_host_names):
        persona_names = list(clients[name].personas.keys()) \
            if name in clients else []
        if persona_names:
            default_index, default_persona = index, persona_names[0]
            break

    radios = []
    for index, name in enumerate(allowed_llm_host_names):
        if name in clients:
            choices = list(clients[name].personas.keys())
            label = f"{name} ({clients[name].vllm_model_name})"
        else:
            choices = []
            label = "Not Authorized"
        # Pass the default through the constructor, not by assigning
        # `.value` afterwards, so it is also part of the component's
        # constructor arguments when the radios are returned as an update.
        value = default_persona if index == default_index else None
        radios.append(gr.Radio(choices=choices, value=value, label=label))

    if default_index is None:
        print("No authorized model with personas; no default persona set")
    else:
        print(f"Set default persona to {default_persona} for "
              f"{allowed_llm_host_names[default_index]}")
    return radios
def init_gradio() -> gr.Blocks:
    """
    Build the Gradio chat UI.

    The UI contains:
    - a ChatInterface backed by `respond`;
    - an options accordion with one persona radio per configured model, plus
      the "conversational" checkbox and "Max new tokens" slider;
    - a login/logout button.
    On page load, the radios and the login button are regenerated for the
    current user. Choosing a persona in one radio clears all other radios, so
    the selection `respond` receives is unambiguous.

    UI text may be customized in the optional `huggingface_text` config
    section: `title`, `version`, and `accordion_info`. The legacy misspelled
    key `accordian_info` is still honored.
    :return: Configured (not yet launched) Gradio Blocks
    """
    conversational_checkbox = gr.Checkbox(value=True, label="conversational")
    max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64,
                                  label="Max new tokens")
    radios = get_model_options(None)

    def make_exclusive_selector(selected_index: int):
        """
        Build a handler for the radio at `selected_index` that keeps that
        radio's new value and clears every other radio.
        """
        def select(value):
            return [value if index == selected_index else None
                    for index in range(len(radios))]
        return select

    with gr.Blocks() as blocks:
        # Compile
        hf_config = config.get("huggingface_text") or dict()
        accordion_info = hf_config.get("accordion_info") or \
            hf_config.get("accordian_info") or \
            "Persona and LLM Options - Choose one:"
        version = hf_config.get("version") or \
            f"v{datetime.now().strftime('%Y-%m-%d')}"
        title = hf_config.get("title") or \
            f"Neon AI BrainForge Personas and Large Language Models ({version})"
        with gr.Accordion(label=accordion_info, open=True,
                          render=False) as accordion:
            for radio in radios:
                radio.render()
            conversational_checkbox.render()
            max_tokens_slider.render()
        _ = gr.ChatInterface(
            respond,
            additional_inputs=[
                conversational_checkbox,
                max_tokens_slider,
                *radios,
            ],
            additional_inputs_accordion=accordion,
            title=title,
            concurrency_limit=5,
        )
        # Render login/logout button; label and link are set per user on load
        login_button = gr.Button("Log In")
        blocks.load(get_login_button, None, login_button)
        accordion.render()
        blocks.load(get_model_options, None, radios)
        # Keep the radios mutually exclusive. `parse_radio_select` uses the
        # first non-None radio, so a leftover selection in an earlier radio
        # would silently override the user's latest choice. `.input` fires
        # only on user interaction, not on the load updates above. With a
        # single radio there is nothing to clear, and a one-element list
        # would be taken as that single output's value, so skip wiring then.
        if len(radios) > 1:
            for index, radio in enumerate(radios):
                radio.input(make_exclusive_selector(index), inputs=radio,
                            outputs=radios)
    return blocks
if __name__ == "__main__":
    # Configuration must be loaded first: the UI is generated from the
    # configured clients and `huggingface_text` settings.
    init_config()
    # Registers the Google OAuth client and adds session middleware to `app`
    init_oauth()
    blocks = init_gradio()
    # Serve the Gradio UI at the root path of the FastAPI app
    app = gr.mount_gradio_app(app=app, blocks=blocks, path='/')
    uvicorn.run(app=app, host='0.0.0.0', port=7860)