Liyonghui commited on
Commit
33915c0
·
1 Parent(s): 8acf5b2

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +209 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chat_bot.py
2
+
3
+ import streamlit as st
4
+ from streamlit_chat import message
5
+ import _thread as thread
6
+ import base64
7
+ import datetime
8
+ import hashlib
9
+ import hmac
10
+ import json
11
+ from urllib.parse import urlparse
12
+ import ssl
13
+ from datetime import datetime
14
+ from time import mktime
15
+ from urllib.parse import urlencode
16
+ from wsgiref.handlers import format_date_time
17
+
18
+ import websocket # 使用websocket_client
19
+
20
+ import sys, io
21
+
22
+ answer = ""
23
+
24
+
25
+ class Ws_Param(object):
26
+ # 初始化
27
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
28
+ self.APPID = APPID
29
+ self.APIKey = APIKey
30
+ self.APISecret = APISecret
31
+ self.host = urlparse(Spark_url).netloc
32
+ self.path = urlparse(Spark_url).path
33
+ self.Spark_url = Spark_url
34
+
35
+ # 生成url
36
+ def create_url(self):
37
+ # 生成RFC1123格式的时间戳
38
+ now = datetime.now()
39
+ date = format_date_time(mktime(now.timetuple()))
40
+
41
+ # 拼接字符串
42
+ signature_origin = "host: " + self.host + "\n"
43
+ signature_origin += "date: " + date + "\n"
44
+ signature_origin += "GET " + self.path + " HTTP/1.1"
45
+
46
+ # 进行hmac-sha256进行加密
47
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
48
+ digestmod=hashlib.sha256).digest()
49
+
50
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
51
+
52
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
53
+
54
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
55
+
56
+ # 将请求的鉴权参数组合为字典
57
+ v = {
58
+ "authorization": authorization,
59
+ "date": date,
60
+ "host": self.host
61
+ }
62
+ # 拼接鉴权参数,生成url
63
+ url = self.Spark_url + '?' + urlencode(v)
64
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
65
+ return url
66
+
67
+
68
+ # 收到websocket错误的处理
69
+ def on_error(ws, error):
70
+ print("### error:", error)
71
+
72
+
73
+ # 收到websocket关闭的处理
74
+ def on_close(ws, one, two):
75
+ print(" ")
76
+
77
+
78
+ # 收到websocket连接建立的处理
79
+ def on_open(ws):
80
+ thread.start_new_thread(run, (ws,))
81
+
82
+
83
+ def run(ws, *args):
84
+ data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
85
+ ws.send(data)
86
+
87
+
88
+ # 收到websocket消息的处理
89
+ def on_message(ws, message):
90
+ # print(message)
91
+ data = json.loads(message)
92
+ code = data['header']['code']
93
+ if code != 0:
94
+ print(f'请求错误: {code}, {data}')
95
+ ws.close()
96
+ else:
97
+ choices = data["payload"]["choices"]
98
+ status = choices["status"]
99
+ content = choices["text"][0]["content"]
100
+ #print(content, end="")
101
+ global answer
102
+ answer += content
103
+ # print(1)
104
+ if status == 2:
105
+ ws.close()
106
+
107
+
108
+ def gen_params(appid, domain, question):
109
+ """
110
+ 通过appid和用户的提问来生成请参数
111
+ """
112
+ data = {
113
+ "header": {
114
+ "app_id": appid,
115
+ "uid": "1234"
116
+ },
117
+ "parameter": {
118
+ "chat": {
119
+ "domain": domain,
120
+ "random_threshold": 0.5,
121
+ "max_tokens": 2048,
122
+ "auditing": "default"
123
+ }
124
+ },
125
+ "payload": {
126
+ "message": {
127
+ "text": question
128
+ }
129
+ }
130
+ }
131
+ return data
132
+
133
+
134
+ def main(appid, api_key, api_secret, Spark_url, domain, question):
135
+ # print("星火:")
136
+ wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
137
+ websocket.enableTrace(False)
138
+ wsUrl = wsParam.create_url()
139
+ ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
140
+ ws.appid = appid
141
+ ws.question = question
142
+ ws.domain = domain
143
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
144
+
145
+
146
+ # 以下密钥信息从控制台获取
147
+ appid = "7e203b6c" # 填写控制台中获取的 APPID 信息
148
+ api_secret = "MjU0MGUwZDI1MDk4MDVjZTg1ZjU0YjZi" # 填写控制台中获取的 APISecret 信息
149
+ api_key = "ef969b3800d8d3a367c06729844f6ab4" # 填写控制台中获取的 APIKey 信息
150
+
151
+ # 用于配置大模型版本,默认“general/generalv2”
152
+ domain = "general" # v1.5版本
153
+ # domain = "generalv2" # v2.0版本
154
+ # 云端环境的服务地址
155
+ Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
156
+ # Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
157
+
158
+
159
+ text = []
160
+
161
+
162
+ # length = 0
163
+
164
+ def getText(role, content):
165
+ jsoncon = {}
166
+ jsoncon["role"] = role
167
+ jsoncon["content"] = content
168
+ text.append(jsoncon)
169
+ return text
170
+
171
+
172
+ def getlength(text):
173
+ length = 0
174
+ for content in text:
175
+ temp = content["content"]
176
+ leng = len(temp)
177
+ length += leng
178
+ return length
179
+
180
+
181
+ def checklen(text):
182
+ while (getlength(text) > 8000):
183
+ del text[0]
184
+ return text
185
+
186
+ def getanswerByXunFei(t):
187
+ text.clear
188
+ question = checklen(getText("user", t))
189
+ ans = ""
190
+ main(appid, api_key, api_secret, Spark_url, domain, question)
191
+ getText("assistant", ans)
192
+ return answer
193
+
194
+ st.markdown("我是讯飞聊天机器人,我可以回答您的任何问题")
195
+ if 'generated' not in st.session_state:
196
+ st.session_state['generated'] = []
197
+ if 'past' not in st.session_state:
198
+ st.session_state['past'] = []
199
+ user_input=st.text_input("请输入您的问题:",key='input')
200
+ if user_input:
201
+ output=getanswerByXunFei(user_input)
202
+ st.session_state['past'].append(user_input)
203
+ st.session_state['generated'].append(output)
204
+ if st.session_state['generated']:
205
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
206
+ message(st.session_state["generated"][i], key=str(i))
207
+ message(st.session_state['past'][i],
208
+ is_user=True,
209
+ key=str(i)+'_user')
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ websocket