try-this-model / app.py
wxgeorge's picture
:recycle: refactor larger model whitelisting.
3fa9161
raw
history blame
5.91 kB
from openai import OpenAI
import gradio as gr
import os
import json
import html
# Featherless exposes an OpenAI-compatible API, so the stock OpenAI client is
# simply pointed at its base URL. Refuse to start when no key is configured.
api_key = os.getenv('FEATHERLESS_API_KEY')
if not api_key:
    raise RuntimeError("Cannot start without required API key. Please register for one at https://featherless.ai")
client = OpenAI(api_key=api_key, base_url="https://api.featherless.ai/v1")

# System prompt prepended to conversations with the Reflection model (see respond()).
REFLECTION_SYSTEM_PROMPT = """You are a world-class AI system, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags."""
def _build_messages(message, history, model):
    """Convert ChatInterface history plus the new *message* into OpenAI chat messages.

    The Reflection model additionally gets REFLECTION_SYSTEM_PROMPT as a leading
    system message.
    """
    messages = []
    if model == REFLECTION:
        messages.append({"role": "system", "content": REFLECTION_SYSTEM_PROMPT})
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # respond() HTML-escapes replies before showing them, and the chat
        # history presumably holds those escaped strings. Undo the escaping so
        # the model sees its own text (e.g. "<thinking>", not "&lt;thinking&gt;").
        # html.unescape is the exact inverse of html.escape.
        if isinstance(assistant, str):
            assistant = html.unescape(assistant)
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


def respond(message, history, model):
    """Stream the selected model's reply to *message*; used as the ChatInterface callback.

    Args:
        message: The user's new message.
        history: Prior turns as ``(user_text, assistant_text)`` pairs. The
            assistant text is presumably the HTML-escaped reply previously
            yielded by this function.
        model: Featherless model id chosen in the dropdown.

    Yields:
        The HTML-escaped reply accumulated so far, once per streamed text chunk.
    """
    response = client.chat.completions.create(
        model=model,
        messages=_build_messages(message, history, model),
        temperature=1.0,
        stream=True,
        max_tokens=2000,
        extra_headers={
            'HTTP-Referer': 'https://huggingface.co/spaces/featherless-ai/try-this-model',
            'X-Title': "HF's missing inference widget",
        },
    )
    partial_message = ""
    for chunk in response:
        # Some OpenAI-compatible servers emit chunks with an empty `choices`
        # list (e.g. a trailing usage chunk); skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is None:
            continue
        # Escape so tags such as <thinking> are displayed literally rather than
        # interpreted as HTML. html.escape works character by character, so
        # escaping each chunk gives the same result as escaping the whole reply.
        partial_message += html.escape(content)
        yield partial_message
# Featherless logo markup, inlined into the footer HTML. Use a context manager
# so the file handle is closed instead of leaked.
with open('./logo.svg', encoding='utf-8') as f_logo:
    logo = f_logo.read()

# model-cache.json maps each model class name to the model ids in that class.
with open('./model-cache.json', 'r', encoding='utf-8') as f_model_cache:
    model_cache = json.load(f_model_cache)

# Reverse index: model id -> model class.
model_class_from_model_id = {
    model_id: model_class
    for model_class, model_ids in model_cache.items()
    for model_id in model_ids
}

# True: every model in the class is offered in the dropdown.
# False: the class is hidden (the larger classes), apart from any model listed
# individually in bigger_whitelisted_models below.
model_class_filter = {
    "mistral-v02-7b-std-lc": True,
    "llama3-8b-8k": True,
    "llama31-8b-16k": True,
    "llama2-solar-10b7-4k": True,
    "mistral-nemo-12b-lc": True,
    "llama2-13b-4k": True,
    "llama3-15b-8k": True,
    "qwen2-32b-lc": False,
    "llama3-70b-8k": False,
    "llama31-70b-16k": False,
    "qwen2-72b-lc": False,
    "mixtral-8x22b-lc": False,
    "llama3-405b-lc": False,
}

# Individual models from otherwise-hidden classes that we also serve here.
REFLECTION = "mattshumer/Reflection-Llama-3.1-70B"
QWEN25_72B = "Qwen/Qwen2.5-72B"
bigger_whitelisted_models = [
    REFLECTION,
    QWEN25_72B,
]

# REFLECTION is served from backup hosting (presumably why model-cache.json
# does not list it), so register its model class by hand.
model_class_from_model_id[REFLECTION] = 'llama31-70b-16k'
def build_model_choices(
    *,
    cache=None,
    class_filter=None,
    extra_models=None,
    class_lookup=None,
):
    """Build the ``(label, model_id)`` choices offered in the model dropdown.

    Every model in a class enabled by *class_filter* is listed first, followed
    by the individually whitelisted *extra_models*. Each label has the form
    ``"<model_id> (<model_class>)"``, and each model id appears at most once.

    Keyword Args:
        cache: Mapping of model class -> model ids. Defaults to ``model_cache``.
        class_filter: Mapping of model class -> bool (True = list the class).
            Classes missing from it are warned about and skipped. Defaults to
            ``model_class_filter``.
        extra_models: Model ids to list even if their class is filtered out.
            Defaults to ``bigger_whitelisted_models``.
        class_lookup: Mapping of model id -> model class, used to label
            *extra_models*. Defaults to ``model_class_from_model_id``.

    Returns:
        A list of ``(label, model_id)`` tuples.
    """
    # Fall back to the module-level tables so existing no-argument calls work.
    if cache is None:
        cache = model_cache
    if class_filter is None:
        class_filter = model_class_filter
    if extra_models is None:
        extra_models = bigger_whitelisted_models
    if class_lookup is None:
        class_lookup = model_class_from_model_id

    choices = []
    seen = set()

    def _add(model_id, model_class):
        # Skip duplicates, e.g. a whitelisted model whose class is also enabled.
        if model_id not in seen:
            seen.add(model_id)
            choices.append((f"{model_id} ({model_class})", model_id))

    for model_class, model_ids in cache.items():
        if model_class not in class_filter:
            print(f"Warning: new model class {model_class}. Treating as blacklisted")
            continue
        if not class_filter[model_class]:
            continue
        for model_id in model_ids:
            _add(model_id, model_class)

    for model_id in extra_models:
        model_class = class_lookup.get(model_id)
        if model_class is None:
            # Don't crash at startup if the class table no longer lists this model.
            print(f"Warning: whitelisted model {model_id} has no known model class")
            model_class = "unknown"
        _add(model_id, model_class)

    return choices
# Dropdown choices, computed once at import time.
model_choices = build_model_choices()


def initial_model(referer=None):
    """Return the model id pre-selected in the dropdown when the page loads.

    Args:
        referer: The page's HTTP Referer header, or None. Currently ignored.

    Returns:
        The model id to select; always QWEN25_72B at present.
    """
    # NOTE: an earlier version used *referer* to preselect the model whose
    # huggingface.co page linked here (or a fixed model for local runs), and
    # otherwise picked a per-day pseudo-random entry from model_choices. That
    # logic is disabled; the argument is still accepted so callers can pass it.
    return QWEN25_72B
# Page stylesheet handed to gr.Blocks.
#  - .logo-mark: fill colour for the logo (presumably an element inside logo.svg).
#  - The remaining rules work around gradio issue #4001 so the chat area keeps
#    filling the viewport height when ChatInterface is nested in gr.Blocks.
css = """
.logo-mark { fill: #ffe184; }
/* from https://github.com/gradio-app/gradio/issues/4001
* necessary as putting ChatInterface in gr.Blocks changes behaviour
*/
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto;}
"""

# Browser tab title for the app.
title_text = "HuggingFace's missing inference widget"
# Build and launch the UI.
# gr.Blocks' first positional parameter is `theme`, so the page title must be
# passed by keyword; passing it positionally makes gradio try to load a theme
# with that name and leaves the page title at its default.
with gr.Blocks(title=title_text, css=css) as demo:
    gr.HTML("""
<h1 align="center">HuggingFace's missing inference widget</h1>
<h2 align="center">
Please select your model from the list 👇
</h2>
""")
    with gr.Row():
        model_selector = gr.Dropdown(
            label="Select your Model",
            # Reuse the list built at import time; calling build_model_choices()
            # again would repeat its warnings.
            choices=model_choices,
            # A callable value is evaluated each time the page loads.
            value=initial_model,
            scale=4,
        )
        # Opens the selected model's Hugging Face page client-side (fn=None: no Python callback).
        gr.Button(
            value="Visit Model Card ↗️",
            scale=1,
        ).click(
            inputs=[model_selector],
            js="(model_selection) => { window.open(`https://huggingface.co/${model_selection}`, '_blank') }",
            fn=None,
        )
    gr.ChatInterface(
        respond,
        additional_inputs=[model_selector],
        head="""
<script>console.log("Hello from gradio!")</script>
""",
        concurrency_limit=5,
    )
    gr.HTML(f"""
<p align="center">
Inference by <a href="https://featherless.ai">{logo}</a>
</p>
""")

    def update_initial_model_choice(request: gr.Request):
        """Select the initial model on page load, based on the request's Referer header."""
        return initial_model(request.headers.get('referer'))

    demo.load(update_initial_model_choice, outputs=model_selector)

demo.launch()