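# Gradio demo for chatting with a GGUF model via llama.cpp on a ZeroGPU Space.
# A llama.cpp server is spawned for each chat turn and queried over its
# OpenAI-compatible /chat/completions endpoint, streaming tokens back to the UI.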
import os
import requests
import subprocess
import json
import time
import spaces
import gradio as gr
from typing import List, Optional, Tuple, Dict

DEFAULT_SYSTEM = "You are a helpful assistant."
HF_MODEL_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
HF_FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"

###############################################

API_PATH_HEALTH = "/health"
API_PATH_COMPLETIONS = "/chat/completions"
LLAMA_CPP_SERVER_BASE = "http://127.0.0.1:8080"
LLAMA_CPP_SERVER_START_TIMEOUT = 50  # seconds
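
# One-time bootstrap: fetch the quantized GGUF model from the Hugging Face Hub and a
# prebuilt llama-server binary from a companion builder Space.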
if not os.path.exists("model.gguf"):
    url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/{HF_FILENAME}"
    subprocess.check_call(["curl", "-o", "model.gguf", "-L", url])

if not os.path.exists("llama-server"):
    # FIXME: currently, we can't build inside the gradio container because nvcc is missing
    subprocess.check_call("curl -o llama-server -L https://ngxson-llamacpp-builder.hf.space/llama-server", shell=True)
    subprocess.check_call("chmod +x llama-server", shell=True)

###############################################

class Role:
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

def clear_session() -> Tuple[str, History]:
    return "", []


def modify_system_session(system: str) -> Tuple[str, str, History]:
    if system is None or len(system) == 0:
        system = DEFAULT_SYSTEM
    return system, system, []
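
# Helpers for converting between Gradio's (user, assistant) pair history and the
# OpenAI-style message list used by the /chat/completions endpoint.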
def history_to_messages(history: History, system: str) -> Messages:
    messages = [{"role": Role.SYSTEM, "content": system}]
    for h in history:
        messages.append({"role": Role.USER, "content": h[0]})
        messages.append({"role": Role.ASSISTANT, "content": h[1]})
    return messages


def messages_to_history(messages: Messages) -> Tuple[str, History]:
    assert messages[0]["role"] == Role.SYSTEM
    system = messages[0]["content"]
    history = []
    for q, r in zip(messages[1::2], messages[2::2]):
        history.append([q["content"], r["content"]])
    return system, history
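
# Poll the server's /health endpoint until it returns 200, giving up after
# LLAMA_CPP_SERVER_START_TIMEOUT retries (roughly one per second).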
def wait_until_llamacpp_ready():
    time.sleep(5)
    gr.Info("starting llama.cpp server...")
    trials = 0
    while True:
        try:
            response = requests.get(LLAMA_CPP_SERVER_BASE + API_PATH_HEALTH)
            if response.status_code == 200:
                print("Status 200 received. Exiting loop.")
                break
            else:
                print(f"Received status {response.status_code}. Retrying...")
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
        trials += 1
        if trials > LLAMA_CPP_SERVER_START_TIMEOUT:
            raise TimeoutError("llama.cpp server did not start in time")
        time.sleep(1)  # Wait for 1 second before retrying
    gr.Info("llama.cpp server is ready.")
    print("llama.cpp server is ready.")
# NOTE: the @spaces.GPU decorator is assumed here; on ZeroGPU Spaces, GPU work must run
# inside a decorated function, and the `spaces` import above is otherwise unused.
@spaces.GPU
def model_chat(query: Optional[str], history: Optional[History], system: str
               ) -> Tuple[str, History, str]:
    if query is None:
        query = ""
    if history is None:
        history = []

    # start llama.cpp server
    proc = subprocess.Popen(["./llama-server"], env=dict(
        os.environ,
        LLAMA_HOST="0.0.0.0",
        LLAMA_PORT="8080",
        LLAMA_ARG_CTX_SIZE=str(1024 * 32),
        LLAMA_ARG_FLASH_ATTN="1",
        LLAMA_ARG_MODEL="model.gguf",
        LLAMA_ARG_N_PARALLEL="1",
        LLAMA_ARG_N_GPU_LAYERS="9999",
        LLAMA_ARG_NO_MMAP="1",
    ))
    exception = None
    try:
        wait_until_llamacpp_ready()

        messages = history_to_messages(history, system)
        messages.append({"role": Role.USER, "content": query})

        # adapted from https://gist.github.com/ggorlen/7c944d73e27980544e29aa6de1f2ac54
        url = LLAMA_CPP_SERVER_BASE + API_PATH_COMPLETIONS
        headers = {
            # "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "temperature": 0.8,
            "messages": messages,
            "stream": True,
        }
        response = requests.post(url, headers=headers, json=data, stream=True)
        response.raise_for_status()

        # accumulate streamed SSE deltas ("data: " lines) and re-emit the growing
        # assistant reply to the UI after each chunk
        curr_text = ""
        for line in response.iter_lines():
            line = line.decode("utf-8")
            if line.startswith("data: ") and not line.endswith("[DONE]"):
                chunk_json = json.loads(line[len("data: "):])
                chunk = chunk_json["choices"][0]["delta"].get("content", "")
                # print(chunk, end="", flush=True)
                curr_text += chunk
                system, history = messages_to_history(messages + [{"role": Role.ASSISTANT, "content": curr_text}])
                yield "", history, system
    except Exception as e:
        print(e)
        exception = e
    finally:
        # clean up: stop the llama.cpp server
        proc.kill()
    if exception is not None:
        # re-raise the exception if needed
        raise exception
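
# Gradio UI: system-prompt box, chat history, and input box wired to the handlers above.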
with gr.Blocks() as demo:
    gr.Markdown(f"""<center><font size=6>{HF_MODEL_ID}</font></center>""")
    with gr.Row():
        with gr.Column(scale=3):
            system_input = gr.Textbox(value=DEFAULT_SYSTEM, lines=1, label="System")
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history", scale=2)
        system_state = gr.Textbox(value=DEFAULT_SYSTEM, visible=False)
    chatbot = gr.Chatbot(label=HF_MODEL_ID)
    textbox = gr.Textbox(lines=2, label="Input")

    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")

    submit.click(model_chat,
                 inputs=[textbox, chatbot, system_state],
                 outputs=[textbox, chatbot, system_input],
                 concurrency_limit=5)
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot])
    modify_system.click(fn=modify_system_session,
                        inputs=[system_input],
                        outputs=[system_state, system_input, chatbot])

demo.queue(api_open=False)
demo.launch(max_threads=5)