nchen909's picture
Upload folder using huggingface_hub
a010884 verified
raw
history blame
8.82 kB
import codecs
import os

import gradio as gr
from huggingface_hub.file_download import http_get
from llama_cpp import Llama
# System prompt placed at the start of every conversation (see get_system_tokens).
SYSTEM_PROMPT = (
    "You are Apollo, a multilingual medical model. "
    "You communicate with people and assist them."
)

# Directory in which GGUF model files are looked up and downloaded.
# NOTE(review): this name shadows the builtin ``dir``; it is kept because
# MODEL_OPTIONS refers to it.
dir = "."
def get_message_tokens(model, role, content):
    """Tokenize a single chat turn.

    The turn is laid out as the role label, a newline, the content, a newline
    and the literal ``</s>`` marker, then UTF-8 encoded and tokenized.

    Args:
        model: object with ``tokenize(data, special=...)`` (a ``llama_cpp.Llama``
            in this app).
        role: speaker label written on the first line (e.g. "system", "user", "bot").
        content: text of the turn.

    Returns:
        The result of ``model.tokenize`` with ``special=True`` so that markers
        like ``</s>`` are parsed as special tokens.
    """
    turn_bytes = "{}\n{}\n</s>".format(role, content).encode("utf-8")
    return model.tokenize(turn_bytes, special=True)
def get_system_tokens(model):
    """Return the tokens of the system turn that carries SYSTEM_PROMPT."""
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)
def load_model(directory, model_name, model_url):
    """Make sure a GGUF model file is present locally, then load it.

    If ``directory/model_name`` does not exist it is downloaded from
    ``model_url``. The download is written to a temporary ``.part`` file and
    only renamed to the final name once it has completed. An interrupted or
    failed download therefore never leaves a truncated file that later runs
    would mistake for a finished model (the existence check below would
    otherwise skip re-downloading it).

    Args:
        directory: folder holding the model file; created if missing.
        model_name: file name of the GGUF model inside ``directory``.
        model_url: URL handed to ``huggingface_hub``'s ``http_get``.

    Returns:
        A ``llama_cpp.Llama`` instance created with ``n_ctx=1024``.

    Raises:
        Whatever ``http_get`` or ``Llama`` raise on download/load failure.
    """
    final_model_path = os.path.join(directory, model_name)
    print(f"Checking model: {model_name}")
    if not os.path.exists(final_model_path):
        print(f"Downloading model: {model_name}")
        if directory:
            os.makedirs(directory, exist_ok=True)
        partial_path = final_model_path + ".part"
        try:
            with open(partial_path, "wb") as f:
                http_get(model_url, f)
        except BaseException:
            # Remove the incomplete download so the next attempt starts clean.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(partial_path, final_model_path)
        # NOTE(review): 0o777 (world-writable) is kept from the original; a
        # read-only mode such as 0o644 is likely sufficient — confirm.
        os.chmod(final_model_path, 0o777)
    print(f"Model {model_name} ready!")
    model = Llama(model_path=final_model_path, n_ctx=1024)
    print(f"Model {model_name} loaded successfully!")
    return model
# Catalogue of selectable models, keyed by "<model type> <size>" as built by
# get_model_key from the two UI dropdowns. Each entry holds the arguments
# passed to load_model: storage directory, local file name and download URL.
# NOTE(review): the "Apollo 0.5B" and "Apollo 2B" URLs look like placeholders
# ("path_to_...") rather than GGUF download links — TODO confirm and replace.
MODEL_OPTIONS = {
    option_key: {
        "directory": dir,
        "model_name": file_name,
        "model_url": download_url,
    }
    for option_key, file_name, download_url in (
        (
            "Apollo 0.5B",
            "apollo-0.5b.gguf",
            "https://huggingface.co/path_to_apollo_0.5b_model",
        ),
        (
            "Apollo 2B",
            "apollo-2b.gguf",
            "https://huggingface.co/path_to_apollo_2b_model",
        ),
        (
            "Apollo 7B",
            "Apollo-7B-q8_0.gguf",
            "https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/resolve/main/Apollo-7B-q8_0.gguf",
        ),
        (
            "Apollo2 0.5B",
            "Apollo-0.5B-q8_0.gguf",
            "https://huggingface.co/FreedomIntelligence/Apollo-0.5B-GGUF/resolve/main/Apollo-0.5B-q8_0.gguf",
        ),
        (
            "Apollo2 2B",
            "Apollo-2B-q8_0.gguf",
            "https://huggingface.co/FreedomIntelligence/Apollo-2B-GGUF/resolve/main/Apollo-2B-q8_0.gguf",
        ),
        (
            "Apollo2 7B",
            "apollo2-7b-q8_0.gguf",
            "https://huggingface.co/nchen909/Apollo2-7B-Q8_0-GGUF/resolve/main/apollo2-7b-q8_0.gguf",
        ),
    )
}

# The currently loaded llama.cpp model and the MODEL_OPTIONS key it came from.
# Both start unset and are managed by initialize_model.
MODEL = None
CURRENT_MODEL_KEY = None
def get_model_key(model_type, model_size):
    """Join the two dropdown selections into a MODEL_OPTIONS key, e.g. "Apollo2 7B"."""
    return "{} {}".format(model_type, model_size)
def initialize_model(model_type, model_size):
    """Load the model selected in the UI, replacing any previously loaded one.

    The MODEL / CURRENT_MODEL_KEY globals cache the loaded model, so calling
    this again with the same selection is a no-op. On failure the error is
    printed and MODEL is left as None rather than raising into the Gradio
    event; bot() then refuses to generate.

    Args:
        model_type: model family from the dropdown (e.g. "Apollo2").
        model_size: size from the dropdown (e.g. "7B").
    """
    global MODEL, CURRENT_MODEL_KEY
    model_key = get_model_key(model_type, model_size)
    # Only reload the model if it's not already loaded
    if CURRENT_MODEL_KEY == model_key and MODEL is not None:
        print(f"Model {model_key} is already loaded.")
        return
    print(f"Initializing model: {model_key}")
    # Drop the previous model before loading the next one. Otherwise both stay
    # referenced during the load and peak memory doubles when switching.
    MODEL = None
    CURRENT_MODEL_KEY = None
    try:
        selected_model = MODEL_OPTIONS[model_key]
        MODEL = load_model(
            directory=selected_model["directory"],
            model_name=selected_model["model_name"],
            model_url=selected_model["model_url"],
        )
    except Exception as e:
        # UI event boundary: report the failure and continue without a model.
        print(f"Failed to initialize model {model_key}: {e}")
        MODEL = None
        return
    CURRENT_MODEL_KEY = model_key
    print(f"Model initialized: {model_key}")
# Functions for chat interactions
def user(message, history, model_type, model_size):
    """Handle a submitted message: load the selected model and queue the turn.

    Args:
        message: text typed by the user.
        history: chat history as a list of [user_text, bot_text] pairs;
            None is treated as an empty history.
        model_type: model family selected in the dropdown.
        model_size: model size selected in the dropdown.

    Returns:
        A tuple ("", new_history). The empty string clears the textbox, and
        new_history is a new list with [message, None] appended. The None slot
        is filled in by bot() while it streams the reply.
    """
    # Make sure the selected model is loaded before bot() runs.
    initialize_model(model_type, model_size)
    new_history = (history or []) + [[message, None]]
    return "", new_history
def bot(history, top_p, top_k, temp):
    """Stream the model's reply to the last user turn in ``history``.

    The prompt is built in this order:
    1. the system turn;
    2. every earlier user/bot turn;
    3. the last user turn;
    4. a trailing "bot" header that the model continues.

    Args:
        history: list of [user_text, bot_text] pairs. The reply slot of the
            last pair is overwritten in place as text is generated.
        top_p: nucleus-sampling threshold from the UI slider.
        top_k: top-k sampling cutoff from the UI slider.
        temp: sampling temperature from the UI slider.

    Yields:
        ``history`` after each generated token, so Gradio can re-render it.

    Raises:
        RuntimeError: if no model has been loaded.
    """
    model = MODEL
    if model is None:
        raise RuntimeError("Model has not been initialized. Please select a model to load.")
    tokens = get_system_tokens(model)[:]
    for user_message, bot_message in history[:-1]:
        tokens.extend(get_message_tokens(model=model, role="user", content=user_message))
        if bot_message:
            tokens.extend(get_message_tokens(model=model, role="bot", content=bot_message))
    last_user_message = history[-1][0]
    tokens.extend(get_message_tokens(model=model, role="user", content=last_user_message))
    tokens.extend(model.tokenize("bot\n".encode("utf-8"), special=True))
    # Slider values may arrive as floats; top-k is an integer count.
    generator = model.generate(tokens, top_k=int(top_k), top_p=top_p, temp=temp)
    # One token can hold only part of a multi-byte UTF-8 character, which is
    # common for non-Latin scripts. Decoding each token separately with
    # errors="ignore" silently dropped those bytes. The incremental decoder
    # buffers an incomplete tail until the rest of the character arrives.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    eos_token = model.token_eos()
    partial_text = ""
    for token in generator:
        if token == eos_token:
            break
        partial_text += decoder.decode(model.detokenize([token]))
        history[-1][1] = partial_text
        yield history
def clear_chat():
    """Reset the conversation by handing the Chatbot a fresh, empty history."""
    empty_history = []
    return empty_history
def stop_generation():
    """Handler for the Stop button; it only logs the request.

    The actual cancellation is requested through the ``cancels=`` argument of
    the Stop button's binding. This function updates nothing and returns None.
    """
    # Implement stop logic if supported
    print("Generation stopped.")
# Gradio UI
# Custom styling: hide the footer, stretch the message box, and keep the
# Send/Stop/Clear buttons side by side. It is passed to gr.Blocks(css=...),
# the documented way to apply CSS. Assigning demo.css after the Blocks context
# and queue() have been set up is not guaranteed to take effect.
CUSTOM_CSS = """
footer {display: none !important;}
#send-message-box {width: 100%;}
#send-btn, #stop-btn, #clear-btn {
    display: inline-block;
    width: 30%;
    margin-right: 2px;
    text-align: center;
}
.gr-row {
    display: flex !important;
    flex-direction: row !important;
    justify-content: space-between;
    align-items: center;
    flex-wrap: nowrap;
}
"""

with gr.Blocks(
    theme=gr.themes.Monochrome(),
    css=CUSTOM_CSS,
    analytics_enabled=False,
) as demo:
    favicon = '<img src="https://huggingface.co/FreedomIntelligence/Apollo2-7B/resolve/main/assets/apollo_medium_final.png" width="148px" style="display: inline">'
    gr.Markdown(
        f"""# {favicon} Apollo GGUF Playground
This is a demo of multilingual medical model series **[Apollo](https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF)**, GGUF version. [Apollo1](https://arxiv.org/abs/2403.03640) covers 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) covers 50 languages.
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(
                label="Send Message",
                placeholder="Send Message",
                show_label=False,
                elem_id="send-message-box",
            )
        with gr.Column(scale=1):
            with gr.Row(equal_height=False):
                model_type = gr.Dropdown(
                    choices=["Apollo", "Apollo2"],
                    value="Apollo2",
                    label="Select Model",
                    interactive=True,
                    elem_id="model-type-dropdown",
                )
                model_size = gr.Dropdown(
                    choices=["0.5B", "2B", "7B"],
                    value="7B",
                    label="Select Size",
                    interactive=True,
                    elem_id="model-size-dropdown",
                )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                interactive=True,
                label="Top-p",
            )
            top_k = gr.Slider(
                minimum=10,
                maximum=100,
                value=30,
                step=5,
                interactive=True,
                label="Top-k",
            )
            temp = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.01,
                step=0.01,
                interactive=True,
                label="Temperature",
            )
            with gr.Row(equal_height=False):
                submit = gr.Button("Send", elem_id="send-btn")
                stop = gr.Button("Stop", elem_id="stop-btn")
                clear = gr.Button("Clear", elem_id="clear-btn")

    # Event bindings: pressing Enter and clicking Send behave the same. First
    # `user` records the message and loads the selected model. Then, only if
    # that succeeded, `bot` streams the reply.
    # NOTE(review): `user` runs with queue=False but may download a multi-GB
    # model on first use — confirm this cannot hit request timeouts.
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, model_type, model_size],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=[chatbot, top_p, top_k, temp],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, model_type, model_size],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=[chatbot, top_p, top_k, temp],
        outputs=chatbot,
        queue=True,
    )
    # Stop cancels whichever generation chain is running.
    stop.click(
        fn=stop_generation,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(fn=clear_chat, inputs=None, outputs=chatbot, queue=False)

demo.queue(max_size=128)

# Models are loaded lazily on the first message. To load the default model
# eagerly at startup instead, uncomment the following line:
# initialize_model("Apollo2", "7B")
if __name__ == "__main__":
    demo.launch(show_error=True, share=True)