# NOTE: the original lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# residue from the Hugging Face Spaces page header, not code; preserved as a
# comment so the module parses.
import os

import gradio as gr
from huggingface_hub.file_download import http_get
from llama_cpp import Llama
# System prompt injected as the "system" turn at the start of every prompt
# built by bot() via get_system_tokens().
SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."
# Define the directory dynamically
# NOTE(review): `dir` shadows the Python builtin of the same name. It is
# referenced by every MODEL_OPTIONS entry below, so renaming it requires a
# coordinated change across the file.
dir = "."
def get_message_tokens(model, role, content):
    """Tokenize one chat turn formatted as '<role>\\n<content>\\n</s>'.

    `special=True` lets the tokenizer interpret the </s> end-of-turn marker
    as a special token rather than literal text.
    """
    turn_text = f"{role}\n{content}\n</s>"
    return model.tokenize(turn_text.encode("utf-8"), special=True)
def get_system_tokens(model):
    """Return the token sequence for the fixed system-prompt turn."""
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)
def load_model(directory, model_name, model_url):
    """Ensure the GGUF model file exists locally, then load it with llama.cpp.

    Parameters:
        directory: folder the model file is stored in.
        model_name: file name of the GGUF model.
        model_url: URL to download the model from when it is missing.

    Returns:
        A ready `Llama` instance with a 1024-token context window.
    """
    final_model_path = os.path.join(directory, model_name)
    print(f"Checking model: {model_name}")
    if not os.path.exists(final_model_path):
        print(f"Downloading model: {model_name}")
        # Download to a temporary path first: an interrupted transfer must not
        # leave a partial file at final_model_path, or the next run's
        # os.path.exists() check would skip the download and Llama() would try
        # to load a corrupt GGUF.
        tmp_path = final_model_path + ".download"
        try:
            with open(tmp_path, "wb") as f:
                http_get(model_url, f)
            # Atomic on POSIX: the complete file appears at the final path.
            os.replace(tmp_path, final_model_path)
        except Exception:
            # Remove the partial download before propagating the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.chmod(final_model_path, 0o777)
    print(f"Model {model_name} ready!")
    model = Llama(model_path=final_model_path, n_ctx=1024)
    print(f"Model {model_name} loaded successfully!")
    return model
# Registry of selectable models. Keys are "<family> <size>" strings exactly as
# produced by get_model_key(); each value is passed straight to load_model().
MODEL_OPTIONS = {
    "Apollo 0.5B": {
        "directory": dir,
        "model_name": "apollo-0.5b.gguf",
        # NOTE(review): placeholder URL — fetching it returns an HTML page,
        # not a GGUF file. Fill in the real repo URL before selecting this.
        "model_url": "https://huggingface.co/path_to_apollo_0.5b_model"
    },
    "Apollo 2B": {
        "directory": dir,
        "model_name": "apollo-2b.gguf",
        # NOTE(review): placeholder URL — not a downloadable GGUF file.
        "model_url": "https://huggingface.co/path_to_apollo_2b_model"
    },
    "Apollo 7B": {
        "directory": dir,
        "model_name": "Apollo-7B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/resolve/main/Apollo-7B-q8_0.gguf"
    },
    # NOTE(review): the "Apollo2 0.5B" and "Apollo2 2B" entries point at the
    # Apollo (v1) GGUF repos, not Apollo2 ones — presumably a deliberate
    # fallback, but confirm against the published model list.
    "Apollo2 0.5B": {
        "directory": dir,
        "model_name": "Apollo-0.5B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-0.5B-GGUF/resolve/main/Apollo-0.5B-q8_0.gguf"
    },
    "Apollo2 2B": {
        "directory": dir,
        "model_name": "Apollo-2B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-2B-GGUF/resolve/main/Apollo-2B-q8_0.gguf"
    },
    "Apollo2 7B": {
        "directory": dir,
        "model_name": "apollo2-7b-q8_0.gguf",
        "model_url": "https://huggingface.co/nchen909/Apollo2-7B-Q8_0-GGUF/resolve/main/apollo2-7b-q8_0.gguf"
    }
}
# Lazily-populated global model handle; set/reset by initialize_model() and
# read by bot().
MODEL = None

def get_model_key(model_type, model_size):
    """Build a MODEL_OPTIONS key, e.g. ('Apollo2', '7B') -> 'Apollo2 7B'."""
    return " ".join((model_type, model_size))
def initialize_model(model_type="Apollo2", model_size="7B"):
    """Load the requested model into the global MODEL slot.

    Any failure (unknown key, download or load error) is logged and leaves
    MODEL set to None instead of raising, so the UI keeps running.
    """
    global MODEL
    key = get_model_key(model_type, model_size)
    try:
        print(f"Initializing model: {key}")
        config = MODEL_OPTIONS[key]
        MODEL = load_model(
            directory=config["directory"],
            model_name=config["model_name"],
            model_url=config["model_url"],
        )
        print(f"Model initialized: {key}")
    except Exception as e:
        print(f"Failed to initialize model {key}: {e}")
        MODEL = None
def user(message, history):
    """Append the new user message (reply pending) and clear the textbox.

    Returns a fresh history list; the caller's list is not mutated.
    """
    return "", [*history, [message, None]]
def bot(history, top_p, top_k, temp):
    """Stream the model's reply into the last history entry.

    Generator: yields the updated history after every decoded token so the
    Gradio chatbot renders incrementally. Raises RuntimeError when no model
    has been loaded yet.
    """
    global MODEL
    if MODEL is None:
        raise RuntimeError("Model has not been initialized. Please select a model to load.")
    model = MODEL

    # Prompt = system turn + all completed turns + the pending user turn.
    prompt_tokens = list(get_system_tokens(model))
    for question, answer in history[:-1]:
        prompt_tokens.extend(get_message_tokens(model=model, role="user", content=question))
        if answer:
            prompt_tokens.extend(get_message_tokens(model=model, role="bot", content=answer))
    prompt_tokens.extend(get_message_tokens(model=model, role="user", content=history[-1][0]))
    # Open the assistant's turn so generation continues as the bot.
    prompt_tokens.extend(model.tokenize("bot\n".encode("utf-8"), special=True))

    reply = ""
    for token in model.generate(prompt_tokens, top_k=top_k, top_p=top_p, temp=temp):
        if token == model.token_eos():
            break
        reply += model.detokenize([token]).decode("utf-8", "ignore")
        history[-1][1] = reply
        yield history
# --- Gradio UI -------------------------------------------------------------
# NOTE(review): the original indentation was lost in extraction; the nesting
# of the sliders/buttons under the right-hand column below follows reading
# order — confirm against the intended layout.
with gr.Blocks(
    theme=gr.themes.Monochrome(),
    analytics_enabled=False,
) as demo:
    # Inline logo rendered inside the Markdown title.
    favicon = '<img src="https://huggingface.co/FreedomIntelligence/Apollo2-7B/resolve/main/assets/apollo_medium_final.png" width="148px" style="display: inline">'
    gr.Markdown(
        f"""# {favicon} Apollo GGUF Playground
This is a demo of multilingual medical model series **[Apollo](https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF)**, GGUF version. [Apollo1](https://arxiv.org/abs/2403.03640) covers 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) covers 50 languages.
"""
    )
    with gr.Row():
        # Left column: conversation view and message input.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(
                label="Send Message",
                placeholder="Send Message",
                show_label=False,
                elem_id="send-message-box"
            )
        # Right column: model selection and sampling parameters.
        with gr.Column(scale=1):
            # Keep model_type and model_size side by side in one gr.Row.
            with gr.Row(equal_height=False):
                model_type = gr.Dropdown(
                    choices=["Apollo", "Apollo2"],
                    value="Apollo2",
                    label="Select Model",
                    interactive=True,
                    elem_id="model-type-dropdown",
                )
                model_size = gr.Dropdown(
                    choices=["0.5B", "2B", "7B"],
                    value="7B",
                    label="Select Size",
                    interactive=True,
                    elem_id="model-size-dropdown",
                )
            #gr.Markdown("### Generation Parameters")
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                interactive=True,
                label="Top-p",
            )
            top_k = gr.Slider(
                minimum=10,
                maximum=100,
                value=30,
                step=5,
                interactive=True,
                label="Top-k",
            )
            temp = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.01,
                step=0.01,
                interactive=True,
                label="Temperature"
            )
            with gr.Row(equal_height=False):
                submit = gr.Button("Send", elem_id="send-btn")
                stop = gr.Button("Stop", elem_id="stop-btn")
                clear = gr.Button("Clear", elem_id="clear-btn")

    # Reload the model whenever either dropdown changes.
    def update_model(model_type, model_size):
        initialize_model(model_type, model_size)
    model_type.change(update_model, [model_type, model_size], None)
    model_size.change(update_model, [model_type, model_size], None)

    # Enter in the textbox: record the user turn immediately (queue=False),
    # then stream the bot reply only if that succeeded.
    msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=[chatbot, top_p, top_k, temp],
        outputs=chatbot,
        queue=True,
    )
# Bound the request queue so streaming generations don't pile up unbounded.
demo.queue(max_size=128)
# NOTE(review): assigning .css after Blocks construction may be ignored by
# newer Gradio versions (css is normally a gr.Blocks(...) argument) — verify
# against the pinned gradio version. The string content (including its
# original Chinese CSS comments) is runtime data and is kept verbatim.
demo.css = """
footer {display: none !important;}
#send-message-box {width: 100%;}
#send-btn, #stop-btn, #clear-btn {
display: inline-block; /* 强制内联块 */
width: 30%; /* 设置按钮宽度为父容器的 30% */
margin-right: 2px; /* 按钮之间增加间距 */
text-align: center; /* 按钮内容居中 */
}
.gr-row {
display: flex !important; /* 强制使用 flex 布局 */
flex-direction: row !important; /* 水平排列 */
justify-content: space-between; /* 组件之间的间距调整 */
align-items: center; /* 垂直居中对齐 */
flex-wrap: nowrap; /* 禁止按钮换行 */
}
"""
# Initialize the default model at startup
initialize_model("Apollo2", "7B")
# share=True opens a public tunnel in addition to the local server.
demo.launch(show_error=True, share=True)