import os

import gradio as gr
from huggingface_hub.file_download import http_get
from llama_cpp import Llama

SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."

# Directory where downloaded GGUF files are stored.
MODEL_DIR = "."

def get_message_tokens(model, role, content):
    """Serialize one chat turn and tokenize it with special tokens enabled."""
    content = f"{role}\n{content}\n</s>"
    content = content.encode("utf-8")
    return model.tokenize(content, special=True)


def get_system_tokens(model):
    """Tokenize the fixed system prompt as the opening turn of every conversation."""
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    return get_message_tokens(model, **system_message)

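# For reference, the helpers above render each turn as plain text in the form
# "<role>\n<content>\n</s>" before tokenizing (assuming the Apollo GGUF
# checkpoints expect this layout, as encoded above). A system turn, rendered:
#
#   system
#   You are Apollo, a multilingual medical model. You communicate with people and assist them.
#   </s>
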
def load_model(directory, model_name, model_url):
    """Download the GGUF file if it is not cached locally, then load it with llama.cpp."""
    final_model_path = os.path.join(directory, model_name)
    print(f"Checking model: {model_name}")
    if not os.path.exists(final_model_path):
        print(f"Downloading model: {model_name}")
        with open(final_model_path, "wb") as f:
            http_get(model_url, f)
        os.chmod(final_model_path, 0o777)
    print(f"Model {model_name} ready!")
    model = Llama(model_path=final_model_path, n_ctx=1024)
    print(f"Model {model_name} loaded successfully!")
    return model

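# Example usage (a sketch, left commented out because it downloads a multi-GB
# file; the URL is one of the real GGUF links from MODEL_OPTIONS below):
#
#   model = load_model(
#       directory=".",
#       model_name="Apollo-7B-q8_0.gguf",
#       model_url="https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/resolve/main/Apollo-7B-q8_0.gguf",
#   )
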
# NOTE: the first two URLs are placeholders; the remaining entries point at
# published GGUF files on the Hugging Face Hub.
MODEL_OPTIONS = {
    "Apollo 0.5B": {
        "directory": MODEL_DIR,
        "model_name": "apollo-0.5b.gguf",
        "model_url": "https://huggingface.co/path_to_apollo_0.5b_model",
    },
    "Apollo 2B": {
        "directory": MODEL_DIR,
        "model_name": "apollo-2b.gguf",
        "model_url": "https://huggingface.co/path_to_apollo_2b_model",
    },
    "Apollo 7B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-7B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/resolve/main/Apollo-7B-q8_0.gguf",
    },
    "Apollo2 0.5B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-0.5B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-0.5B-GGUF/resolve/main/Apollo-0.5B-q8_0.gguf",
    },
    "Apollo2 2B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-2B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-2B-GGUF/resolve/main/Apollo-2B-q8_0.gguf",
    },
    "Apollo2 7B": {
        "directory": MODEL_DIR,
        "model_name": "apollo2-7b-q8_0.gguf",
        "model_url": "https://huggingface.co/nchen909/Apollo2-7B-Q8_0-GGUF/resolve/main/apollo2-7b-q8_0.gguf",
    },
}

MODEL = None
CURRENT_MODEL_KEY = None


def get_model_key(model_type, model_size):
    """Combine the two dropdown values into a MODEL_OPTIONS key."""
    return f"{model_type} {model_size}"

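# e.g. get_model_key("Apollo2", "7B") -> "Apollo2 7B", matching the keys of
# MODEL_OPTIONS above.
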
def initialize_model(model_type, model_size):
    """Load the selected model dynamically."""
    global MODEL, CURRENT_MODEL_KEY
    model_key = get_model_key(model_type, model_size)
    # Only reload the model if it's not already loaded.
    if CURRENT_MODEL_KEY == model_key and MODEL is not None:
        print(f"Model {model_key} is already loaded.")
        return
    print(f"Initializing model: {model_key}")
    try:
        selected_model = MODEL_OPTIONS[model_key]
        MODEL = load_model(
            directory=selected_model["directory"],
            model_name=selected_model["model_name"],
            model_url=selected_model["model_url"],
        )
        CURRENT_MODEL_KEY = model_key
        print(f"Model initialized: {model_key}")
    except Exception as e:
        print(f"Failed to initialize model {model_key}: {e}")
        MODEL = None

# Functions for chat interactions
def user(message, history, model_type, model_size):
    """Handle user input and dynamically initialize the selected model."""
    initialize_model(model_type, model_size)
    new_history = history + [[message, None]]
    return "", new_history

def bot(history, top_p, top_k, temp):
    """Stream a response from the currently loaded model, token by token."""
    if MODEL is None:
        raise RuntimeError("Model has not been initialized. Please select a model to load.")
    model = MODEL
    # Rebuild the full prompt: system turn, then every prior exchange.
    tokens = get_system_tokens(model)[:]
    for user_message, bot_message in history[:-1]:
        tokens.extend(get_message_tokens(model=model, role="user", content=user_message))
        if bot_message:
            tokens.extend(get_message_tokens(model=model, role="bot", content=bot_message))
    last_user_message = history[-1][0]
    tokens.extend(get_message_tokens(model=model, role="user", content=last_user_message))
    # Open the assistant turn so the model continues as "bot".
    tokens.extend(model.tokenize("bot\n".encode("utf-8"), special=True))
    generator = model.generate(tokens, top_k=top_k, top_p=top_p, temp=temp)
    partial_text = ""
    for token in generator:
        if token == model.token_eos():
            break
        partial_text += model.detokenize([token]).decode("utf-8", "ignore")
        # Yield after each token so the Gradio chatbot updates incrementally.
        history[-1][1] = partial_text
        yield history

def clear_chat():
    """Clear the chat history."""
    return []


def stop_generation():
    """Stop hook; the actual cancellation is done by the `cancels` list on stop.click below."""
    print("Generation stopped.")
    return None

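# If the `cancels` mechanism ever proves insufficient, one alternative is a
# cooperative flag checked inside bot()'s token loop. A sketch, not wired into
# bot() above (STOP_REQUESTED is a name introduced here, not part of this app):
#
#   STOP_REQUESTED = False
#
#   def stop_generation():
#       global STOP_REQUESTED
#       STOP_REQUESTED = True
#       return None
#
#   # inside bot()'s loop: `if STOP_REQUESTED: break` before detokenizing
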
# Gradio UI
with gr.Blocks(theme=gr.themes.Monochrome(), analytics_enabled=False) as demo:
    favicon = '<img src="https://huggingface.co/FreedomIntelligence/Apollo2-7B/resolve/main/assets/apollo_medium_final.png" width="148px" style="display: inline">'
    gr.Markdown(
        f"""# {favicon} Apollo GGUF Playground
This is a demo of the multilingual medical model series **[Apollo](https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF)**, GGUF version. [Apollo1](https://arxiv.org/abs/2403.03640) covers 6 languages; [Apollo2](https://arxiv.org/abs/2410.10626) covers 50 languages.
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(
                label="Send Message",
                placeholder="Send Message",
                show_label=False,
                elem_id="send-message-box",
            )
        with gr.Column(scale=1):
            with gr.Row(equal_height=False):
                model_type = gr.Dropdown(
                    choices=["Apollo", "Apollo2"],
                    value="Apollo2",
                    label="Select Model",
                    interactive=True,
                    elem_id="model-type-dropdown",
                )
                model_size = gr.Dropdown(
                    choices=["0.5B", "2B", "7B"],
                    value="7B",
                    label="Select Size",
                    interactive=True,
                    elem_id="model-size-dropdown",
                )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                interactive=True,
                label="Top-p",
            )
            top_k = gr.Slider(
                minimum=10,
                maximum=100,
                value=30,
                step=5,
                interactive=True,
                label="Top-k",
            )
            temp = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.01,
                step=0.01,
                interactive=True,
                label="Temperature",
            )
            with gr.Row(equal_height=False):
                submit = gr.Button("Send", elem_id="send-btn")
                stop = gr.Button("Stop", elem_id="stop-btn")
                clear = gr.Button("Clear", elem_id="clear-btn")
    # Event bindings
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, model_type, model_size],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=[chatbot, top_p, top_k, temp],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, model_type, model_size],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=[chatbot, top_p, top_k, temp],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=stop_generation,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(fn=clear_chat, inputs=None, outputs=chatbot, queue=False)

demo.queue(max_size=128)
demo.css = """
footer {display: none !important;}
#send-message-box {width: 100%;}
#send-btn, #stop-btn, #clear-btn {
    display: inline-block;
    width: 30%;
    margin-right: 2px;
    text-align: center;
}
.gr-row {
    display: flex !important;
    flex-direction: row !important;
    justify-content: space-between;
    align-items: center;
    flex-wrap: nowrap;
}
"""
# Initialize the default model at startup (currently deferred: user() loads it
# on the first message instead).
# initialize_model("Apollo2", "7B")

demo.launch(show_error=True, share=True)