Update app.py
app.py CHANGED
@@ -1,8 +1,11 @@
 import gradio as gr
 import boto3
+import sagemaker
 import json
 import io
 import os
+from transformers import AutoTokenizer
+
 
 region = os.getenv("region")
 sm_endpoint_name = os.getenv("sm_endpoint_name")
@@ -16,20 +19,28 @@ session = boto3.Session(
 )
 sess = sagemaker.Session(boto_session=session)
 
+smr = session.client("sagemaker-runtime")
+
+DEFAULT_SYSTEM_PROMPT = (
+    "You are an helpful, concise and direct Assistant."
+)
+
+# load the tokenizer
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+
+
+MAX_INPUT_TOKEN_LENGTH = 256
+
 # hyperparameters for llm
 parameters = {
     "do_sample": True,
     "top_p": 0.6,
     "temperature": 0.9,
-    "max_new_tokens":
+    "max_new_tokens": 768,
+    "repetition_penalty": 1.2,
     "return_full_text": False,
-    "stop": ["</s>"],
 }
 
-system_prompt = (
-    "You are an helpful Assistant, called Llama 2. Knowing everyting about AWS."
-)
-
 
 # Helper for reading lines from a stream
 class LineIterator:
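
Note on the tokenizer added above: the prompt string sent to the endpoint is produced by the model's own chat template, so its exact shape depends on the template bundled with the checkpoint. A minimal sketch of what apply_chat_template returns here (assumes Hub access; the output shown is roughly what the Mistral-7B-Instruct-v0.2 template yields):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    messages = [{"role": "user", "content": "Hello!"}]
    # tokenize=False returns the formatted string rather than token ids;
    # add_generation_prompt=True asks for the cue that starts the assistant turn.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(prompt)  # roughly: <s>[INST] Hello! [/INST]

Worth noting: this model's chat template only accepts alternating user/assistant roles and raises on a leading "system" message, which may be why the system-prompt variant in the next hunk is commented out.
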
@@ -61,66 +72,86 @@ class LineIterator:
             self.buffer.write(chunk["PayloadPart"]["Bytes"])
 
 
-
-
-    messages = []
-
-    messages.append({"role": "
-
-        messages.append({"role": "user", "content": user_prompt})
-        messages.append({"role": "assistant", "content": bot_response})
+def format_prompt(message, history):
+    '''
+    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
+    for interaction in history:
+        messages.append({"role": "user", "content": interaction[0]})
+        messages.append({"role": "assistant", "content": interaction[1]})
     messages.append({"role": "user", "content": message})
-
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    '''
 
+    messages = [
+        {"role": "user", "content": "Can you tell me an interesting fact about AWS?"},]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    return prompt
 
-
-
-
-
-    system_prompt=system_prompt,
-    tokenizer=None,
-    concurrency_count=4,
-    share=True,
+
+def generate(
+    prompt,
+    history,
 ):
-
-
-    def generate(
-        prompt,
-        history,
-    ):
-        messages = create_messages_dict(prompt, history, system_prompt)
-        formatted_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+    formatted_prompt = format_prompt(prompt, history)
+    check_input_token_length(formatted_prompt)
 
-
-        resp = smr.invoke_endpoint_with_response_stream(
-            EndpointName=endpoint_name,
-            Body=json.dumps(request),
-            ContentType="application/json",
-        )
+    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
 
-
-
-
-
-
-            if chunk["token"]["special"]:
-                continue
-            if chunk["token"]["text"] in request["parameters"]["stop"]:
-                break
-            output += chunk["token"]["text"]
-            for stop_str in request["parameters"]["stop"]:
-                if output.endswith(stop_str):
-                    output = output[: -len(stop_str)]
-                    output = output.rstrip()
-                    yield output
-
-        yield output
-        return output
-
-    demo = gr.ChatInterface(
-        generate, title="Chat with Amazon SageMaker", chatbot=gr.Chatbot(layout="panel")
+
+    resp = smr.invoke_endpoint_with_response_stream(
+        EndpointName=endpoint_name,
+        Body=json.dumps(request),
+        ContentType="application/json",
     )
+    output = ""
+    for c in LineIterator(resp["Body"]):
+        c = c.decode("utf-8")
+        if c.startswith("data:"):
+            chunk = json.loads(c.lstrip("data:").rstrip("/n"))
+            if chunk["token"]["special"]:
+                continue
+            if chunk["token"]["text"] in request["parameters"]["stop"]:
+                break
+            output += chunk["token"]["text"]
+            for stop_str in request["parameters"]["stop"]:
+                if output.endswith(stop_str):
+                    output = output[: -len(stop_str)]
+                    output = output.rstrip()
+                    yield output
+
+    yield output
+    return output
+
+
+def check_input_token_length(prompt: str) -> None:
+    input_token_length = len(tokenizer(prompt)["input_ids"])
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(
+            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
+        )
+
+
+theme = gr.themes.Monochrome(
+    primary_hue="indigo",
+    secondary_hue="blue",
+    neutral_hue="slate",
+    radius_size=gr.themes.sizes.radius_sm,
+    font=[
+        gr.themes.GoogleFont("Open Sans"),
+        "ui-sans-serif",
+        "system-ui",
+        "sans-serif",
+    ],
+)
+
+
+demo = gr.ChatInterface(
+    generate,
+    chatbot=gr.Chatbot(layout="panel"),
+    theme=theme,
+)
 
-
+demo.queue(concurrency_count=5).launch(share=False)
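
Note on the streaming loop added above: each line read from the response stream is a server-sent event in the format emitted by Hugging Face Text Generation Inference, i.e. a "data:" prefix followed by a JSON payload carrying a token object (fields such as id, text, logprob, special). A small self-contained sketch of the parsing step; the sample payload is illustrative only:

    import json

    line = 'data:{"token": {"id": 264, "text": " fact", "special": false}}\n'

    # str.lstrip/rstrip strip *characters*, not prefixes, so the committed
    # c.lstrip("data:").rstrip("/n") works only by accident (and "/n" looks
    # like a typo for "\n"). Slicing the prefix off is unambiguous:
    if line.startswith("data:"):
        chunk = json.loads(line[len("data:"):].strip())
        if not chunk["token"]["special"]:
            print(chunk["token"]["text"])  # -> " fact"

Also note that the loop indexes request["parameters"]["stop"], but this commit removes the "stop" entry from parameters, so that lookup now raises KeyError on the first non-special token unless "stop" is reinstated.
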
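One more observation: as committed, format_prompt ignores both message and history and always sends the hard-coded "interesting fact about AWS" question, which reads like leftover debug code. A hedged sketch of the history-aware version the commented-out block appears to intend, minus the system message (see the template note above); it assumes the module-level tokenizer and Gradio's [[user, assistant], ...] history format:

    def format_prompt(message, history):
        # Rebuild the alternating user/assistant turns, then append the new message.
        messages = []
        for user_turn, assistant_turn in history:
            messages.append({"role": "user", "content": user_turn})
            messages.append({"role": "assistant", "content": assistant_turn})
        messages.append({"role": "user", "content": message})
        # tokenizer is the module-level AutoTokenizer loaded above
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )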