cszhzleo committed
Commit b5a3b2c
Parent: 3a2bcc5

Update app.py

Files changed (1):
  1. app.py +92 -61
app.py CHANGED
@@ -1,8 +1,11 @@
 import gradio as gr
 import boto3
+import sagemaker
 import json
 import io
 import os
+from transformers import AutoTokenizer
+
 
 region = os.getenv("region")
 sm_endpoint_name = os.getenv("sm_endpoint_name")
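Note: the app is configured entirely through environment variables; the next hunk shows a boto3.Session being built from them, with its arguments elided here. A minimal sketch of supplying that configuration when running the file locally — the credential variable names are assumptions, since they are not visible in this diff; on a Hugging Face Space these would be set as Space secrets instead:

import os

# Hypothetical local configuration; the first two names mirror the
# os.getenv() calls visible in the diff, the values are examples.
os.environ["region"] = "us-east-1"                  # AWS region of the endpoint
os.environ["sm_endpoint_name"] = "my-tgi-endpoint"  # placeholder endpoint name
# Assumed credential variables for boto3.Session(...) -- not shown in the diff.
os.environ["aws_access_key_id"] = "AKIA..."
os.environ["aws_secret_access_key"] = "..."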
@@ -16,20 +19,28 @@ session = boto3.Session(
 )
 sess = sagemaker.Session(boto_session=session)
 
+smr = session.client("sagemaker-runtime")
+
+DEFAULT_SYSTEM_PROMPT = (
+    "You are an helpful, concise and direct Assistant."
+)
+
+# load the tokenizer
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+
+
+MAX_INPUT_TOKEN_LENGTH = 256
+
 # hyperparameters for llm
 parameters = {
     "do_sample": True,
     "top_p": 0.6,
     "temperature": 0.9,
-    "max_new_tokens": 1024,
+    "max_new_tokens": 768,
+    "repetition_penalty": 1.2,
     "return_full_text": False,
-    "stop": ["</s>"],
 }
 
-system_prompt = (
-    "You are an helpful Assistant, called Llama 2. Knowing everyting about AWS."
-)
-
 
 # Helper for reading lines from a stream
 class LineIterator:
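The parameters dict follows the Hugging Face text-generation-inference (TGI) generate schema; every request the app sends is just a prompt plus these settings. For a quick sanity check of the endpoint outside of streaming, a hedged sketch using the same payload shape (illustrative only, not part of the commit; smr, parameters and sm_endpoint_name are the module-level names defined above):

import json

def invoke_once(prompt_text):
    # Same payload shape generate() builds, minus the streaming flag.
    payload = {"inputs": prompt_text, "parameters": parameters}
    resp = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    # TGI containers typically answer with [{"generated_text": "..."}].
    return json.loads(resp["Body"].read())[0]["generated_text"]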
@@ -61,66 +72,86 @@ class LineIterator:
             self.buffer.write(chunk["PayloadPart"]["Bytes"])
 
 
-# helper method to format prompt
-def create_messages_dict(message, history, system_prompt):
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    for user_prompt, bot_response in history:
-        messages.append({"role": "user", "content": user_prompt})
-        messages.append({"role": "assistant", "content": bot_response})
+def format_prompt(message, history):
+    '''
+    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
+    for interaction in history:
+        messages.append({"role": "user", "content": interaction[0]})
+        messages.append({"role": "assistant", "content": interaction[1]})
     messages.append({"role": "user", "content": message})
-    return messages
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    '''
 
+    messages = [
+        {"role": "user", "content": "Can you tell me an interesting fact about AWS?"},]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    return prompt
 
-def create_gradio_app(
-    endpoint_name,
-    session=boto3,
-    parameters=parameters,
-    system_prompt=system_prompt,
-    tokenizer=None,
-    concurrency_count=4,
-    share=True,
+
+def generate(
+    prompt,
+    history,
 ):
-    smr = session.client("sagemaker-runtime")
-
-    def generate(
-        prompt,
-        history,
-    ):
-        messages = create_messages_dict(prompt, history, system_prompt)
-        formatted_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+    formatted_prompt = format_prompt(prompt, history)
+    check_input_token_length(formatted_prompt)
 
-        request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
-        resp = smr.invoke_endpoint_with_response_stream(
-            EndpointName=endpoint_name,
-            Body=json.dumps(request),
-            ContentType="application/json",
-        )
+    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
 
-        output = ""
-        for c in LineIterator(resp["Body"]):
-            c = c.decode("utf-8")
-            if c.startswith("data:"):
-                chunk = json.loads(c.lstrip("data:").rstrip("/n"))
-                if chunk["token"]["special"]:
-                    continue
-                if chunk["token"]["text"] in request["parameters"]["stop"]:
-                    break
-                output += chunk["token"]["text"]
-                for stop_str in request["parameters"]["stop"]:
-                    if output.endswith(stop_str):
-                        output = output[: -len(stop_str)]
-                        output = output.rstrip()
-                        yield output
-
-        yield output
-        return output
-
-    demo = gr.ChatInterface(
-        generate, title="Chat with Amazon SageMaker", chatbot=gr.Chatbot(layout="panel")
+
+    resp = smr.invoke_endpoint_with_response_stream(
+        EndpointName=endpoint_name,
+        Body=json.dumps(request),
+        ContentType="application/json",
     )
+    output = ""
+    for c in LineIterator(resp["Body"]):
+        c = c.decode("utf-8")
+        if c.startswith("data:"):
+            chunk = json.loads(c.lstrip("data:").rstrip("/n"))
+            if chunk["token"]["special"]:
+                continue
+            if chunk["token"]["text"] in request["parameters"]["stop"]:
+                break
+            output += chunk["token"]["text"]
+            for stop_str in request["parameters"]["stop"]:
+                if output.endswith(stop_str):
+                    output = output[: -len(stop_str)]
+                    output = output.rstrip()
+                    yield output
+
+    yield output
+    return output
+
+
+def check_input_token_length(prompt: str) -> None:
+    input_token_length = len(tokenizer(prompt)["input_ids"])
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(
+            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
+        )
+
+
+theme = gr.themes.Monochrome(
+    primary_hue="indigo",
+    secondary_hue="blue",
+    neutral_hue="slate",
+    radius_size=gr.themes.sizes.radius_sm,
+    font=[
+        gr.themes.GoogleFont("Open Sans"),
+        "ui-sans-serif",
+        "system-ui",
+        "sans-serif",
+    ],
+)
+
+
+demo = gr.ChatInterface(
+    generate,
+    chatbot=gr.Chatbot(layout="panel"),
+    theme=theme,
+)
 
-    demo.queue(concurrency_count=concurrency_count).launch(share=share)
+demo.queue(concurrency_count=5).launch(share=False)
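Review note: as committed, format_prompt ignores both message and history and always sends the hard-coded test question; the chat-template version that would use them sits inert inside the ''' ... ''' string. If that commented-out path is the intended behavior, a sketch of it with the system message left out (Mistral-Instruct chat templates reject a "system" role, which may be why that line was disabled):

def format_prompt(message, history):
    # Rebuild the conversation from Gradio's (user, assistant) history pairs
    # and render it with the model's chat template. This mirrors the block
    # commented out in the commit; a sketch, not the author's final code.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )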
 
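Two further issues survive the rewrite: generate() still references endpoint_name, which no longer exists once the create_gradio_app() wrapper is removed (the env-derived sm_endpoint_name is presumably intended), and it still reads request["parameters"]["stop"] even though this commit drops the "stop" key from parameters, which raises a KeyError on the first streamed token. The rstrip("/n") carried over from the old code also looks like a typo for rstrip("\n"). A hedged sketch of generate() with those three points patched, reusing the module-level names from app.py:

import json

STOP_SEQUENCES = ["</s>"]  # reinstated locally; the commit removed "stop"
                           # from `parameters` but generate() still reads it


def generate(prompt, history):
    # smr, parameters, tokenizer, LineIterator, sm_endpoint_name,
    # format_prompt and check_input_token_length are defined in app.py.
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {
        "inputs": formatted_prompt,
        "parameters": {**parameters, "stop": STOP_SEQUENCES},
        "stream": True,
    }
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=sm_endpoint_name,  # was the undefined `endpoint_name`
        Body=json.dumps(request),
        ContentType="application/json",
    )

    output = ""
    for line in LineIterator(resp["Body"]):
        line = line.decode("utf-8")
        if not line.startswith("data:"):
            continue
        chunk = json.loads(line.lstrip("data:").rstrip("\n"))  # "\n", not "/n"
        token = chunk["token"]
        if token["special"]:
            continue
        if token["text"] in STOP_SEQUENCES:
            break
        output += token["text"]
        yield output
    yield output

Separately, if the Space runs Gradio 4.x, demo.queue(concurrency_count=5) raises a TypeError: queue() lost that argument in 4.0, and the closest equivalent is demo.queue(default_concurrency_limit=5).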