ngxson's picture
ngxson HF staff
init version
cfcd3f9
raw
history blame
6.09 kB
import json
import os
import subprocess
import time
from typing import Dict, Iterator, List, Optional, Tuple
from urllib.parse import urlsplit

import requests
import spaces
import gradio as gr
# --- Model selection -------------------------------------------------------
# System prompt used when the user leaves the "System" box empty.
DEFAULT_SYSTEM = "You are a helpful assistant."
# Hugging Face repo and the GGUF file inside it that gets downloaded to model.gguf.
HF_MODEL_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
HF_FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"

###############################################
# --- Local llama.cpp server ------------------------------------------------
LLAMA_CPP_SERVER_BASE = "http://127.0.0.1:8080"
API_PATH_HEALTH = "/health"  # readiness probe
API_PATH_COMPLETIONS = "/chat/completions"  # streamed chat endpoint
# How long to wait for the server to report healthy (~seconds).
LLAMA_CPP_SERVER_START_TIMEOUT = 50
def _download(url: str, dest: str, executable: bool = False) -> None:
    """Download *url* to *dest* with curl, replacing *dest* atomically.

    The payload is written to ``dest + ".part"`` and only renamed onto *dest*
    once curl has exited successfully. An interrupted download therefore never
    leaves a truncated file that the ``os.path.exists`` checks below would
    mistake for a complete one. ``--fail`` makes curl exit non-zero on HTTP
    errors instead of saving the error page as the payload.
    ``subprocess.check_call`` turns that non-zero exit into ``CalledProcessError``.

    Args:
        url: Source URL; redirects are followed (``--location``).
        dest: Final path of the downloaded file.
        executable: If true, add execute permission (like ``chmod +x``)
            before the file is moved into place.
    """
    tmp = dest + ".part"
    # List form: no shell involved, so no quoting/injection concerns.
    subprocess.check_call(["curl", "--fail", "--location", "--output", tmp, url])
    if executable:
        os.chmod(tmp, os.stat(tmp).st_mode | 0o111)
    os.replace(tmp, dest)


if not os.path.exists("model.gguf"):
    _download(
        f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/{HF_FILENAME}",
        "model.gguf",
    )
if not os.path.exists("llama-server"):
    # FIXME: currently, we can't build inside gradio container because nvcc is missing
    _download(
        "https://ngxson-llamacpp-builder.hf.space/llama-server",
        "llama-server",
        executable=True,
    )
###############################################
class Role:
    """Role names used in the chat ``messages`` payload sent to the llama.cpp server."""

    SYSTEM: str = "system"
    USER: str = "user"
    ASSISTANT: str = "assistant"
# One chat exchange: (user message, assistant reply).
Turn = Tuple[str, str]
# Chat history as exchanged with the gr.Chatbot component: a list of exchanges.
History = List[Turn]
# One chat message, e.g. {"role": "user", "content": "..."}.
Message = Dict[str, str]
# Ordered message list sent to the chat-completions endpoint.
Messages = List[Message]
def clear_session() -> Tuple[str, "History"]:
    """Reset the chat UI.

    Returns:
        A pair ``("", [])``: an empty input textbox value and a fresh, empty
        chat history. The previous ``-> History`` annotation did not match this
        tuple return.
    """
    return "", []
def modify_system_session(system: Optional[str]) -> Tuple[str, str, "History"]:
    """Apply a new system prompt and clear the chat history.

    Args:
        system: Text from the "System" box. ``None`` or an empty string
            falls back to ``DEFAULT_SYSTEM``.

    Returns:
        ``(system, system, [])``: the effective prompt for the hidden state
        box, the same prompt echoed back to the visible box, and an empty
        chat history. The previous ``-> str`` annotation did not match.
    """
    if not system:
        system = DEFAULT_SYSTEM
    return system, system, []
def history_to_messages(history: History, system: str) -> Messages:
    """Flatten Chatbot-style history into a chat-completions message list.

    The result begins with a single system message carrying *system*. It is
    followed, for every entry of *history*, by a user message (``entry[0]``)
    and an assistant message (``entry[1]``), in order.
    """
    out: Messages = [{"role": Role.SYSTEM, "content": system}]
    for turn in history:
        out.extend((
            {"role": Role.USER, "content": turn[0]},
            {"role": Role.ASSISTANT, "content": turn[1]},
        ))
    return out
def messages_to_history(messages: Messages) -> Tuple[str, History]:
    """Split a chat-completions message list back into ``(system, history)``.

    Args:
        messages: A system message followed by alternating user/assistant
            messages.

    Returns:
        The system prompt and a list of ``[user, assistant]`` pairs. A
        trailing user message with no assistant reply is dropped, because
        ``zip`` stops at the shorter slice.

    Raises:
        ValueError: if *messages* is empty or does not start with a system
            message. This replaces an ``assert``, which is stripped under ``-O``.
    """
    if not messages or messages[0]["role"] != Role.SYSTEM:
        raise ValueError("messages must start with a system message")
    system = messages[0]["content"]
    # messages[1::2] are the user turns, messages[2::2] the assistant replies.
    history = [
        [question["content"], reply["content"]]
        for question, reply in zip(messages[1::2], messages[2::2])
    ]
    return system, history
def wait_until_llamacpp_ready(
    timeout: float = LLAMA_CPP_SERVER_START_TIMEOUT,
    poll_interval: float = 1.0,
) -> None:
    """Block until the llama.cpp server's health endpoint answers HTTP 200.

    Polls ``LLAMA_CPP_SERVER_BASE + API_PATH_HEALTH``. Connection errors and
    non-200 statuses are treated as "not ready yet" and retried. There is no
    fixed start-up sleep: a server that is not listening yet shows up as a
    ``RequestException`` and is simply polled again.

    Args:
        timeout: Wall-clock budget for the whole wait, in seconds. It is
            measured with a monotonic clock rather than by counting attempts.
        poll_interval: Pause between unsuccessful polls, in seconds.

    Raises:
        TimeoutError: if the server is not healthy before the deadline.
    """
    gr.Info("starting llama.cpp server...")
    url = LLAMA_CPP_SERVER_BASE + API_PATH_HEALTH
    deadline = time.monotonic() + timeout
    while True:
        try:
            # Bound each probe, so a server that accepts the connection but
            # never answers cannot stall this loop indefinitely.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                print("Status 200 received. Exiting loop.")
                break
            print(f"Received status {response.status_code}. Retrying...")
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
        if time.monotonic() >= deadline:
            raise TimeoutError("llama.cpp server did not start in time")
        time.sleep(poll_interval)
    gr.Info("llama.cpp server is ready.")
    print("llama.cpp server is ready.")
# (connect, read) timeouts in seconds for the streamed completion request. The
# read timeout bounds the gap between two streamed chunks, not the whole reply.
LLAMA_CPP_REQUEST_TIMEOUT = (10, 300)


def _start_llama_server() -> subprocess.Popen:
    """Launch ``./llama-server`` as a child process, configured through env vars.

    llama-server reads its options from ``LLAMA_ARG_*`` variables. ``--host``
    and ``--port`` map to ``LLAMA_ARG_HOST`` and ``LLAMA_ARG_PORT``; the
    previously used ``LLAMA_HOST``/``LLAMA_PORT`` names are not recognized.
    The server binds to the host and port of ``LLAMA_CPP_SERVER_BASE``,
    because that is where the client connects. This also keeps the server
    off public interfaces.
    """
    server = urlsplit(LLAMA_CPP_SERVER_BASE)
    env = dict(
        os.environ,
        LLAMA_ARG_HOST=server.hostname or "127.0.0.1",
        LLAMA_ARG_PORT=str(server.port or 8080),
        LLAMA_ARG_CTX_SIZE=str(1024 * 32),
        LLAMA_ARG_FLASH_ATTN="1",
        LLAMA_ARG_MODEL="model.gguf",
        LLAMA_ARG_N_PARALLEL="1",
        LLAMA_ARG_N_GPU_LAYERS="9999",
        LLAMA_ARG_NO_MMAP="1",
    )
    return subprocess.Popen(["./llama-server"], env=env)


def _stop_llama_server(proc: subprocess.Popen) -> None:
    """Kill the server and reap it, so no zombie process is left behind."""
    proc.kill()
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        print("llama-server did not exit within 10s after kill")


def _stream_chat_completion(messages: Messages) -> Iterator[str]:
    """POST *messages* with ``stream=True`` and yield the reply accumulated so far.

    The response is consumed line by line. Only lines starting with
    ``data: `` are parsed, and a ``data: [DONE]`` line ends the stream. Each
    JSON chunk contributes ``choices[0].delta.content``. That field may be
    missing (and is treated defensively if null), and chunks with no choices
    are skipped.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    url = LLAMA_CPP_SERVER_BASE + API_PATH_COMPLETIONS
    payload = {
        "temperature": 0.8,
        "messages": messages,
        "stream": True,
    }
    # Using the response as a context manager guarantees the streamed
    # connection is released, even if the consumer stops iterating early.
    with requests.post(
        url, json=payload, stream=True, timeout=LLAMA_CPP_REQUEST_TIMEOUT
    ) as response:
        response.raise_for_status()
        text = ""
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf-8")
            if not line.startswith("data: "):
                continue  # blank separators / non-data lines
            body = line[len("data: "):]
            if body.strip() == "[DONE]":
                break
            chunk = json.loads(body)
            choices = chunk.get("choices") or []
            if not choices:
                continue
            delta = choices[0].get("delta") or {}
            text += delta.get("content") or ""
            yield text


@spaces.GPU
def model_chat(query: Optional[str], history: Optional[History], system: str
               ) -> Iterator[Tuple[str, History, str]]:
    """Answer *query* with a freshly started llama.cpp server, streaming the reply.

    A llama-server process is started for this call and is always killed and
    reaped afterwards, including on error or when the generator is closed early.

    Args:
        query: The new user message. ``None`` is treated as "".
        history: Prior ``[user, assistant]`` pairs. ``None`` is treated as [].
        system: System prompt for the conversation.

    Yields:
        ``("", history, system)`` after each streamed chunk. The empty string
        clears the input box, and ``history`` includes the partial reply.

    Raises:
        TimeoutError: if the server does not become healthy in time.
        requests.RequestException: on HTTP/transport failures while streaming.
    """
    if query is None:
        query = ""
    if history is None:
        history = []
    proc = _start_llama_server()
    try:
        wait_until_llamacpp_ready()
        messages = history_to_messages(history, system)
        messages.append({"role": Role.USER, "content": query})
        for reply in _stream_chat_completion(messages):
            new_system, new_history = messages_to_history(
                messages + [{"role": Role.ASSISTANT, "content": reply}]
            )
            yield "", new_history, new_system
    except Exception as e:
        print(e)
        raise
    finally:
        _stop_llama_server(proc)
# Gradio UI: system-prompt controls, the chat window, and input/action buttons.
with gr.Blocks() as demo:
    # Title; the <font> tag is now properly closed.
    gr.Markdown(f"""<center><font size=6>{HF_MODEL_ID}</font></center>""")
    with gr.Row():
        with gr.Column(scale=3):
            system_input = gr.Textbox(value=DEFAULT_SYSTEM, lines=1, label="System")
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history", scale=2)
        # Hidden copy of the system prompt that is actually fed to model_chat.
        system_state = gr.Textbox(value=DEFAULT_SYSTEM, visible=False)
    chatbot = gr.Chatbot(label=HF_MODEL_ID)
    textbox = gr.Textbox(lines=2, label="Input")
    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")
    # NOTE(review): each model_chat call launches its own llama-server on the
    # fixed host/port of LLAMA_CPP_SERVER_BASE. With concurrency_limit > 1,
    # concurrent calls may contend for that port. Confirm how ZeroGPU
    # isolates workers before relying on this limit.
    submit.click(model_chat,
                 inputs=[textbox, chatbot, system_state],
                 outputs=[textbox, chatbot, system_input],
                 concurrency_limit=5)
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot])
    modify_system.click(fn=modify_system_session,
                        inputs=[system_input],
                        outputs=[system_state, system_input, chatbot])
demo.queue(api_open=False)
demo.launch(max_threads=5)