DiamondYin commited on
Commit
3ab2099
1 Parent(s): fbd10e8
Files changed (3) hide show
  1. app.py +7 -1
  2. app_utils.py +11 -2
  3. testlinkspark.py +142 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import openai
4
  import time
5
  import gradio as gr
6
  from threading import Thread #线程 用于定时器
 
7
 
8
  from assets.char_poses_base64 import ( #角色动作
9
  CHAR_IDLE_HTML, CHAR_THINKING_HTML, CHAR_TALKING_HTML)
@@ -77,7 +78,12 @@ def get_response(history, audio_input):
77
  if question.lower().strip() == 'hi':
78
  question = 'hello'
79
 
80
- answer = conv_model.run(question)
 
 
 
 
 
81
  LOGGER.info("\ndocument_response: %s", answer)
82
  print('\ndocument_response:', answer)
83
 
 
4
  import time
5
  import gradio as gr
6
  from threading import Thread #线程 用于定时器
7
+ import testlinkspark
8
 
9
  from assets.char_poses_base64 import ( #角色动作
10
  CHAR_IDLE_HTML, CHAR_THINKING_HTML, CHAR_TALKING_HTML)
 
78
  if question.lower().strip() == 'hi':
79
  question = 'hello'
80
 
81
+ #answer = conv_model.run(question)
82
+ answer = testlinkspark.main(appid="d2ff57e0",
83
+ api_secret="YjlmNDdkYjFmMGMzYjc5MmJiODFjN2Fi",
84
+ api_key="07963fbc530a42f4ad223517decfd5fe",
85
+ gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
86
+ question= question)
87
  LOGGER.info("\ndocument_response: %s", answer)
88
  print('\ndocument_response:', answer)
89
 
app_utils.py CHANGED
@@ -41,8 +41,17 @@ def initialize_knowledge_base():
41
  char_text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
42
  doc_texts = char_text_splitter.split_documents(docs)
43
 
44
- openAI_embeddings = OpenAIEmbeddings()
45
- vStore = Chroma.from_documents(doc_texts, openAI_embeddings)
 
 
 
 
 
 
 
 
 
46
 
47
  conv_model = RetrievalQA.from_chain_type(
48
  llm=OpenAI(),
 
41
  char_text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
42
  doc_texts = char_text_splitter.split_documents(docs)
43
 
44
+ #调整embeddings
45
+
46
+ from langchain.embeddings import HuggingFaceEmbeddings
47
+ embedding = HuggingFaceInstructEmbeddings(model_name = "hkunlp/instructor-large")
48
+
49
+ #openAI_embeddings = OpenAIEmbeddings()
50
+ #vStore = Chroma.from_documents(doc_texts, openAI_embeddings)
51
+ #调整模型
52
+
53
+
54
+
55
 
56
  conv_model = RetrievalQA.from_chain_type(
57
  llm=OpenAI(),
testlinkspark.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import _thread as thread
2
+ import base64
3
+ import datetime
4
+ import hashlib
5
+ import hmac
6
+ import json
7
+ from urllib.parse import urlparse
8
+ import ssl
9
+ from datetime import datetime
10
+ from time import mktime
11
+ from urllib.parse import urlencode
12
+ from wsgiref.handlers import format_date_time
13
+ # websocket-client
14
+ import websocket
15
+
16
+ class Ws_Param(object):
17
+ # 初始化
18
+ def __init__(self, APPID, APIKey, APISecret, gpt_url):
19
+ self.APPID = APPID
20
+ self.APIKey = APIKey
21
+ self.APISecret = APISecret
22
+ self.host = urlparse(gpt_url).netloc
23
+ self.path = urlparse(gpt_url).path
24
+ self.gpt_url = gpt_url
25
+
26
+ # 生成url
27
+ def create_url(self):
28
+ # 生成RFC1123格式的时间戳
29
+ now = datetime.now()
30
+ date = format_date_time(mktime(now.timetuple()))
31
+
32
+ # 拼接字符串
33
+ signature_origin = "host: " + self.host + "\n"
34
+ signature_origin += "date: " + date + "\n"
35
+ signature_origin += "GET " + self.path + " HTTP/1.1"
36
+
37
+ # 进行hmac-sha256进行加密
38
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
39
+ digestmod=hashlib.sha256).digest()
40
+
41
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
42
+
43
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
44
+
45
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
46
+
47
+ # 将请求的鉴权参数组合为字典
48
+ v = {
49
+ "authorization": authorization,
50
+ "date": date,
51
+ "host": self.host
52
+ }
53
+ # 拼接鉴权参数,生成url
54
+ url = self.gpt_url + '?' + urlencode(v)
55
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
56
+ return url
57
+
58
+
59
+ # 收到websocket错误的处理
60
+ def on_error(ws, error):
61
+ print("### error:", error)
62
+
63
+
64
+ # 收到websocket关闭的处理
65
+ def on_close(ws, status_code, reason):
66
+ print("")
67
+
68
+
69
+
70
+
71
+ # 收到websocket连接建立的处理
72
+ def on_open(ws):
73
+ thread.start_new_thread(run, (ws,))
74
+
75
+
76
+ def run(ws, *args):
77
+ data = json.dumps(gen_params(appid=ws.appid, question=ws.question))
78
+ ws.send(data)
79
+
80
+
81
+ # 收到websocket消息的处理
82
+ def on_message(ws, message):
83
+ # print(message)
84
+ data = json.loads(message)
85
+ code = data['header']['code']
86
+ if code != 0:
87
+ print(f'请求错误: {code}, {data}')
88
+ ws.close()
89
+ else:
90
+ choices = data["payload"]["choices"]
91
+ status = choices["status"]
92
+ content = choices["text"][0]["content"]
93
+ print(content, end='')
94
+ if status == 2:
95
+ ws.close()
96
+
97
+
98
+ def gen_params(appid, question):
99
+ """
100
+ 通过appid和用户的提问来生成请参数
101
+ """
102
+ data = {
103
+ "header": {
104
+ "app_id": appid,
105
+ "uid": "1234"
106
+ },
107
+ "parameter": {
108
+ "chat": {
109
+ "domain": "general",
110
+ "random_threshold": 0.5,
111
+ "max_tokens": 2048,
112
+ "auditing": "default"
113
+ }
114
+ },
115
+ "payload": {
116
+ "message": {
117
+ "text": [
118
+ {"role": "user", "content": question}
119
+ ]
120
+ }
121
+ }
122
+ }
123
+ return data
124
+
125
+
126
+ def main(appid, api_key, api_secret, gpt_url, question):
127
+ wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
128
+ websocket.enableTrace(False)
129
+ wsUrl = wsParam.create_url()
130
+ ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
131
+ ws.appid = appid
132
+ ws.question = question
133
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
134
+
135
+
136
+ if __name__ == "__main__":
137
+ # 测试时候在此处正确填写相关信息即可运行
138
+ main(appid="d2ff57e0",
139
+ api_secret="YjlmNDdkYjFmMGMzYjc5MmJiODFjN2Fi",
140
+ api_key="07963fbc530a42f4ad223517decfd5fe",
141
+ gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
142
+ question="鲁迅和周树人是同一个人吗?")