Lohia, Aditya committed on
Commit
a9409d4
1 Parent(s): 12e4d9f

Updated Spaces

Files changed (4)
  1. app.py +147 -0
  2. dialog.py +45 -0
  3. gateway.py +90 -0
  4. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import gradio as gr
+ from typing import Iterator
+
+ from dialog import get_dialog_box
+ from gateway import check_server_health, request_generation
+
+ # CONSTANTS
+ MAX_NEW_TOKENS: int = 2048
+
+ # GET ENVIRONMENT VARIABLES
+ CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
+
+
+ def toggle_ui():
+     """
+     Toggle the visibility of the UI based on the server health.
+     Returns:
+         updates that hide/show the main UI and the dialog box
+     """
+     health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API)
+     if health:
+         return gr.update(visible=True), gr.update(visible=False)  # Show main UI, hide dialog
+     else:
+         return gr.update(visible=False), gr.update(visible=True)  # Hide main UI, show dialog
+
+
+ def generate(
+         message: str,
+         chat_history: list,
+         system_prompt: str,
+         max_new_tokens: int = 1024,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         top_k: int = 50,
+         repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     """Send a request to the backend, fetch the streaming responses and emit them to the UI.
+
+     Args:
+         message (str): input message from the user
+         chat_history (list[tuple[str, str]]): entire chat history of the session
+         system_prompt (str): system prompt
+         max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the number of tokens in the
+             prompt. Defaults to 1024.
+         temperature (float, optional): the value used to modulate the next token probabilities. Defaults to 0.6.
+         top_p (float, optional): if set to float<1, only the smallest set of most probable tokens with probabilities
+             that add up to top_p or higher are kept for generation. Defaults to 0.9.
+         top_k (int, optional): the number of highest probability vocabulary tokens to keep for top-k filtering.
+             Defaults to 50.
+         repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty.
+             Defaults to 1.2.
+
+     Yields:
+         Iterator[str]: streaming responses to the UI
+     """
+     # Stream responses from the backend and emit the accumulated text to the UI
+     outputs = []
+     for text in request_generation(message=message,
+                                    system_prompt=system_prompt,
+                                    max_new_tokens=max_new_tokens,
+                                    temperature=temperature,
+                                    top_p=top_p,
+                                    top_k=top_k,
+                                    repetition_penalty=repetition_penalty,
+                                    cloud_gateway_api=CLOUD_GATEWAY_API):
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(label="System prompt", lines=6),
+         gr.Slider(
+             label="Max New Tokens",
+             minimum=1,
+             maximum=MAX_NEW_TOKENS,
+             step=1,
+             value=1024,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.1,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.95,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
+     ],
+     cache_examples=False,
+     chatbot=gr.Chatbot(
+         height=600)
+ )
+
+ with gr.Blocks(css="style.css", theme=gr.themes.Default()) as demo:
+     # Get the server status before displaying the UI
+     visibility = check_server_health(CLOUD_GATEWAY_API)
+
+     # Container for the main interface
+     with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
+         gr.Markdown(f"""
+ # Llama-3 8B Chat
+ This Space is an Alpha release that demonstrates the model [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta, a Llama 3 model with 8B parameters fine-tuned for chat instructions, running on AMD MI210 infrastructure. Feel free to play with it!
+ """)
+         chat_interface.render()
+
+     # Dialog box using HTML for the error message
+     with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
+         # Add spinner and message
+         get_dialog_box()
+
+     # Timer to check server health every 10 seconds and update the UI
+     timer = gr.Timer(value=10)
+     timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=int(os.getenv("QUEUE"))).launch()
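For reference, app.py expects two environment variables: API_ENDPOINT, the base URL of the inference gateway (used for the /health check and the /v1/chat/completions requests), and QUEUE, the maximum queue size passed to demo.queue(). Below is a minimal local-run sketch; the URL and queue size are placeholder assumptions, not values from this commit.

import os

# Hypothetical local settings; substitute your own gateway URL and queue size.
os.environ["API_ENDPOINT"] = "http://localhost:8000"  # assumed gateway exposing /health and /v1/chat/completions
os.environ["QUEUE"] = "10"  # must parse as an int, since app.py calls int(os.getenv("QUEUE"))

import app  # importing app.py builds the `demo` Blocks; launch() only runs under __main__

app.demo.queue(max_size=int(os.environ["QUEUE"])).launch()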
dialog.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+
+
+ def get_dialog_box():
+     return gr.HTML("""
+     <div style="display: flex; align-items: center; justify-content: center; min-height: 80vh;">
+         <div style="display: flex; flex-direction: column; align-items: center;">
+             <!-- Spinner -->
+             <div class="loader" style="margin-top: 20px;"></div>
+             <!-- Message -->
+             <h2 style="color: orange; font-family: trebuchet ms, sans-serif; align-items: center;">The service is not working, please refresh or try again later!</h2>
+         </div>
+     </div>
+
+     <!-- Spinner CSS -->
+     <style>
+     /* HTML: <div class="loader"></div> */
+     .loader {
+         width: 120px;
+         height: 22px;
+         border-radius: 40px;
+         color: orange !important;
+         border: 2px solid;
+         position: relative;
+         overflow: hidden;
+     }
+     .loader::before {
+         content: "";
+         position: absolute;
+         margin: 2px;
+         width: 14px;
+         top: 0;
+         bottom: 0;
+         left: -20px;
+         border-radius: inherit;
+         background: currentColor;
+         box-shadow: -10px 0 12px 3px currentColor;
+         clip-path: polygon(0 5%, 100% 0, 100% 100%, 0 95%, -30px 50%);
+         animation: l14 1s infinite linear;
+     }
+     @keyframes l14 {
+         100% {left: calc(100% + 20px)}
+     }
+     </style>
+     """)
gateway.py ADDED
@@ -0,0 +1,90 @@
+ import json
+ import requests
+
+
+ def check_server_health(cloud_gateway_api: str):
+     """
+     Use the appropriate API endpoint to check the server health.
+     Args:
+         cloud_gateway_api: API endpoint to probe.
+
+     Returns:
+         True if the server is active, False otherwise.
+     """
+     try:
+         response = requests.get(cloud_gateway_api + "/health")
+         if response.status_code == 200:
+             return True
+     except requests.ConnectionError:
+         print("Failed to establish connection to the server.")
+
+     return False
+
+
+ def request_generation(message: str,
+                        system_prompt: str,
+                        cloud_gateway_api: str,
+                        max_new_tokens: int = 1024,
+                        temperature: float = 0.6,
+                        top_p: float = 0.9,
+                        top_k: int = 50,
+                        repetition_penalty: float = 1.2, ):
+     """
+     Request streaming generation from the cloud gateway API. Uses the requests module with stream=True to consume
+     token-by-token generation from the LLM.
+
+     Args:
+         message: prompt from the user.
+         system_prompt: system prompt to prepend.
+         cloud_gateway_api (str): API endpoint to send the request.
+         max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
+         temperature: the value used to modulate the next token probabilities.
+         top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
+             or higher are kept for generation.
+         top_k: the number of highest probability vocabulary tokens to keep for top-k filtering.
+         repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
+
+     Yields:
+         Generated text chunks as they are streamed back from the server.
+     """
+
+     payload = {
+         "model": "meta-llama/Meta-Llama-3-8B-Instruct",
+         "messages": [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": message}
+         ],
+         "max_tokens": max_new_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "repetition_penalty": repetition_penalty,
+         "top_k": top_k,
+         "stream": True  # Enable streaming
+     }
+
+     with requests.post(cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True) as response:
+         for chunk in response.iter_lines():
+             if chunk:
+                 # Convert the chunk from bytes to a string
+                 chunk_str = chunk.decode('utf-8')
+
+                 # Remove the `data: ` prefix from the chunk if it exists
+                 if chunk_str.startswith("data: "):
+                     chunk_str = chunk_str[len("data: "):]
+
+                 # Stop at the end-of-stream marker
+                 if chunk_str.strip() == "[DONE]":
+                     break
+
+                 # Parse the chunk into a JSON object
+                 try:
+                     chunk_json = json.loads(chunk_str)
+                     # Extract the "content" field from the choices
+                     content = chunk_json["choices"][0]["delta"].get("content", "")
+
+                     # Yield the generated content as it's streamed
+                     if content:
+                         yield content
+                 except json.JSONDecodeError:
+                     # Handle any potential errors in decoding
+                     continue
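Because request_generation is a plain generator, the gateway client can also be smoke-tested outside the Gradio UI. A minimal sketch, assuming the gateway is reachable at a placeholder localhost URL; the URL and prompts below are illustrative, not part of this commit.

from gateway import check_server_health, request_generation

GATEWAY = "http://localhost:8000"  # hypothetical endpoint; use your real API_ENDPOINT value

if check_server_health(cloud_gateway_api=GATEWAY):
    # Stream tokens to stdout as they arrive from /v1/chat/completions
    for token in request_generation(
            message="Hello there! How are you doing?",
            system_prompt="You are a helpful assistant.",
            cloud_gateway_api=GATEWAY,
            max_new_tokens=64,
    ):
        print(token, end="", flush=True)
    print()
else:
    print("Gateway health check failed; is the server running?")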
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy==1.26.4
+ pillow==10.4.0
+ gradio==4.43.0
+ fastapi==0.111.1
+ websockets==11.0.3