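"""Gradio demo that chats with a GGUF model through a locally spawned llama.cpp server.

On startup the app downloads the model and a prebuilt llama-server binary. Each chat
request launches the server (on the GPU allocated by @spaces.GPU), streams the
completion back to the UI, and then shuts the server down again.
"""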
import os
import requests
import subprocess
import json
import time
import spaces
import gradio as gr
from typing import List, Optional, Tuple, Dict

DEFAULT_SYSTEM = "You are a helpful assistant."
HF_MODEL_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
HF_FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"

###############################################

API_PATH_HEALTH = "/health"
API_PATH_COMPLETIONS = "/chat/completions"
LLAMA_CPP_SERVER_BASE = "http://127.0.0.1:8080"
LLAMA_CPP_SERVER_START_TIMEOUT = 50  # seconds

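# Download the GGUF model from Hugging Face if it is not already present locally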
if not os.path.exists('model.gguf'):
    url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/{HF_FILENAME}"
    subprocess.check_call(["curl", "-o", "model.gguf", "-L", url])

if not os.path.exists("llama-server"):
    # FIXME: currently, we can't build inside gradio container because nvcc is missing
    subprocess.check_call("curl -o llama-server -L https://ngxson-llamacpp-builder.hf.space/llama-server", shell=True)
    subprocess.check_call("chmod +x llama-server", shell=True)

###############################################

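# Chat roles matching the OpenAI-style /chat/completions message format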
class Role:
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

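# Reset the input textbox and the chat history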
def clear_session() -> Tuple[str, History]:
    return "", []

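# Apply a new system prompt (falling back to the default when empty) and clear the chat history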
def modify_system_session(system: str) -> Tuple[str, str, History]:
    if system is None or len(system) == 0:
        system = DEFAULT_SYSTEM
    return system, system, []

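# Convert Gradio (user, assistant) history pairs into an OpenAI-style message list, prefixed with the system prompt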
def history_to_messages(history: History, system: str) -> Messages:
    messages = [{"role": Role.SYSTEM, "content": system}]
    for h in history:
        messages.append({"role": Role.USER, "content": h[0]})
        messages.append({"role": Role.ASSISTANT, "content": h[1]})
    return messages


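# Convert an OpenAI-style message list back into (system prompt, Gradio history) form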
def messages_to_history(messages: Messages) -> Tuple[str, History]:
    assert messages[0]["role"] == Role.SYSTEM
    system = messages[0]["content"]
    history = []
    for q, r in zip(messages[1::2], messages[2::2]):
        history.append((q["content"], r["content"]))
    return system, history

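# Poll the server's /health endpoint until it responds with 200, or raise after roughly LLAMA_CPP_SERVER_START_TIMEOUT seconds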
def wait_until_llamacpp_ready():
    time.sleep(5)
    gr.Info("starting llama.cpp server...")
    trials = 0
    while True:
        try:
            response = requests.get(LLAMA_CPP_SERVER_BASE + API_PATH_HEALTH, timeout=1)  # short timeout so a hung request doesn't stall the health check
            if response.status_code == 200:
                print("Status 200 received. Exiting loop.")
                break
            else:
                print(f"Received status {response.status_code}. Retrying...")
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
        trials += 1
        if trials > LLAMA_CPP_SERVER_START_TIMEOUT:
            raise TimeoutError("llama.cpp server did not start in time")
        time.sleep(1)  # Wait for 1 second before retrying
    gr.Info("llama.cpp server is ready.")
    print("llama.cpp server is ready.")


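# GPU entry point (@spaces.GPU): spawn llama-server, stream the chat completion back to the UI, then shut the server down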
@spaces.GPU
def model_chat(query: Optional[str], history: Optional[History],
               system: str) -> Tuple[str, History, str]:
    if query is None:
        query = ""
    if history is None:
        history = []

    # start the llama.cpp HTTP server; configuration is passed via LLAMA_ARG_* environment variables
    proc = subprocess.Popen(["./llama-server"], env=dict(
        os.environ,
        LLAMA_HOST="0.0.0.0",
        LLAMA_PORT="8080",
        LLAMA_ARG_CTX_SIZE=str(1024 * 32),
        LLAMA_ARG_FLASH_ATTN="1",
        LLAMA_ARG_MODEL="model.gguf",
        LLAMA_ARG_N_PARALLEL="1",
        LLAMA_ARG_N_GPU_LAYERS="9999",
        LLAMA_ARG_NO_MMAP="1",
    ))

    exception = None
    try:
        wait_until_llamacpp_ready()

        messages = history_to_messages(history, system)
        messages.append({"role": Role.USER, "content": query})

        # adapted from https://gist.github.com/ggorlen/7c944d73e27980544e29aa6de1f2ac54
        url = LLAMA_CPP_SERVER_BASE + API_PATH_COMPLETIONS
        headers = {
            # no Authorization header is needed; the llama.cpp server runs locally without an API key
            "Content-Type": "application/json"
        }
        data = {
            "temperature": 0.8,
            "messages": messages,
            "stream": True
        }
        response = requests.post(url, headers=headers, json=data, stream=True)
        response.raise_for_status()

        curr_text = ""
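        # the server streams Server-Sent Events; each "data: " line carries a JSON chunk with an incremental content delta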
        for line in response.iter_lines():
            line = line.decode("utf-8")

            if line.startswith("data: ") and not line.endswith("[DONE]"):
                data = json.loads(line[len("data: "):])
                chunk = data["choices"][0]["delta"].get("content", "")
                # print(chunk, end="", flush=True)
                curr_text += chunk
                system, history = messages_to_history(messages + [{"role": Role.ASSISTANT, "content": curr_text}])
                yield "", history, system
    except Exception as e:
        print(e)
        exception = e
    finally:
        # clean up
        proc.kill()
        if exception is not None:
            # re-raise the exception if needed
            raise exception


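# Build the Gradio UI: system prompt controls, chatbot, input box, and send/clear buttons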
with gr.Blocks() as demo:
    gr.Markdown(f"""<center><font size=6>{HF_MODEL_ID}</font></center>""")

    with gr.Row():
        with gr.Column(scale=3):
            system_input = gr.Textbox(value=DEFAULT_SYSTEM, lines=1, label="System")
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history", scale=2)
        system_state = gr.Textbox(value=DEFAULT_SYSTEM, visible=False)
    chatbot = gr.Chatbot(label=HF_MODEL_ID)
    textbox = gr.Textbox(lines=2, label="Input")

    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")

    submit.click(model_chat,
                 inputs=[textbox, chatbot, system_state],
                 outputs=[textbox, chatbot, system_input],
                 concurrency_limit=5)
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot])
    modify_system.click(fn=modify_system_session,
                        inputs=[system_input],
                        outputs=[system_state, system_input, chatbot])

demo.queue(api_open=False)
demo.launch(max_threads=5)