gpt-4 / conversations /conversation_connector.py
Hansimov's picture
:gem: [Feature] Support system prompt via previousMessages key
16c8074
raw
history blame
5.66 kB
import json
import urllib
import urllib.parse

import aiohttp

from networks import (
    ChathubRequestPayloadConstructor,
    ConversationRequestHeadersConstructor,
    MessageParser,
    OpenaiStreamOutputer,
)
from conversations import ConversationStyle
from utils.logger import logger
from utils.enver import enver
class ConversationConnector:
    """
    Open a websocket to the Bing ChatHub, send one prompt and stream the reply.

    Input params:
        - `conversation_style` (str):
            - Case-insensitive; any value that is not a `ConversationStyle`
              value falls back to `ConversationStyle.PRECISE`.
        - `sec_access_token`, `client_id`, `conversation_id`
            - Generated by `ConversationCreator`
        - `invocation_id` (int):
            - For 1st request, this value must be `0`.
            - For all requests after, any integer is valid.
            - To make it simple, use `1` for all requests after the 1st one.
            - `stream_chat` increments it after every finished conversation.
        - `cookies` (dict, optional):
            - Passed to the `aiohttp.ClientSession`; defaults to no cookies.
    """

    def __init__(
        self,
        conversation_style: str = ConversationStyle.PRECISE.value,
        sec_access_token: str = "",
        client_id: str = "",
        conversation_id: str = "",
        invocation_id: int = 0,
        cookies=None,
    ):
        conversation_style_enum_values = [
            style.value for style in ConversationStyle.__members__.values()
        ]
        if conversation_style.lower() not in conversation_style_enum_values:
            self.conversation_style = ConversationStyle.PRECISE.value
        else:
            self.conversation_style = conversation_style.lower()
        print(f"Model: [{self.conversation_style}]")
        self.sec_access_token = sec_access_token
        self.client_id = client_id
        self.conversation_id = conversation_id
        self.invocation_id = invocation_id
        # A fresh dict per instance: a shared `{}` default would be the same
        # object for every connector created without explicit cookies.
        self.cookies = {} if cookies is None else cookies
        # Populated by `connect()`; None until then so cleanup can tell
        # whether there is anything to close.
        self.aiohttp_session = None
        self.wss = None

    async def wss_send(self, message):
        """Serialize `message` as JSON and send it as a single record.

        Each record is terminated by the ASCII record separator (0x1E), the
        same delimiter `stream_chat` splits incoming frames on.
        """
        serialized_websocket_message = json.dumps(message, ensure_ascii=False) + "\x1e"
        await self.wss.send_str(serialized_websocket_message)

    async def init_handshake(self):
        """Perform the protocol handshake on a freshly opened websocket.

        Sends the JSON protocol selection, waits for one text reply (its
        content is not inspected), then sends a type-6 (heartbeat) record.
        NOTE(review): this looks like the SignalR JSON hub handshake —
        confirm against the upstream service.
        """
        await self.wss_send({"protocol": "json", "version": 1})
        await self.wss.receive_str()
        await self.wss_send({"type": 6})

    async def _close_connection(self):
        """Close the websocket and the HTTP session if open; safe to repeat."""
        # getattr keeps this safe even if `__init__` was bypassed.
        wss = getattr(self, "wss", None)
        if wss is not None and not wss.closed:
            await wss.close()
        session = getattr(self, "aiohttp_session", None)
        if session is not None and not session.closed:
            await session.close()

    async def connect(self):
        """Open the HTTP session and ChatHub websocket, then run the handshake.

        Any session/websocket left open by a previous call is closed first,
        since it would otherwise be overwritten and leaked. If opening the
        websocket or the handshake fails, the new session is closed before
        the exception propagates.
        """
        await self._close_connection()
        self.quotelized_sec_access_token = urllib.parse.quote(self.sec_access_token)
        self.ws_url = (
            f"wss://sydney.bing.com/sydney/ChatHub"
            f"?sec_access_token={self.quotelized_sec_access_token}"
        )
        self.aiohttp_session = aiohttp.ClientSession(cookies=self.cookies)
        try:
            headers_constructor = ConversationRequestHeadersConstructor()
            enver.set_envs(proxies=True)
            self.wss = await self.aiohttp_session.ws_connect(
                self.ws_url,
                headers=headers_constructor.request_headers,
                proxy=enver.proxy,
            )
            await self.init_handshake()
        except BaseException:
            # Release the session (and any half-open socket) on failure or
            # cancellation, then let the caller see the original error.
            await self._close_connection()
            raise

    async def send_chathub_request(self, prompt: str, system_prompt: str = None):
        """Build the ChatHub request payload for `prompt` and send it.

        The payload is built by `ChathubRequestPayloadConstructor` from this
        connector's style/ids plus the optional `system_prompt`, and is kept
        on `self.connect_request_payload`.
        """
        payload_constructor = ChathubRequestPayloadConstructor(
            prompt=prompt,
            conversation_style=self.conversation_style,
            client_id=self.client_id,
            conversation_id=self.conversation_id,
            invocation_id=self.invocation_id,
            system_prompt=system_prompt,
        )
        self.connect_request_payload = payload_constructor.request_payload
        await self.wss_send(self.connect_request_payload)

    async def stream_chat(
        self, prompt: str = "", system_prompt: str = None, yield_output=False
    ):
        """Connect, send `prompt`, and consume the streamed reply.

        Params:
            - `prompt`: user message to send.
            - `system_prompt`: optional system prompt forwarded to the
              request payload constructor.
            - `yield_output`: when True, this async generator yields a
              "Role" chunk first, then every chunk returned by
              `MessageParser.parse(..., return_output=True)`, and finally a
              "Finished" chunk. When False it yields nothing and the parser
              just consumes the messages.

        The websocket and HTTP session are closed on every exit path:
        normal completion, an exception, or the consumer closing the
        generator early (`aclose()`). `self.invocation_id` is incremented
        when the end-of-conversation record (type 3) arrives.
        """
        try:
            await self.connect()
            await self.send_chathub_request(prompt=prompt, system_prompt=system_prompt)
            message_parser = MessageParser(outputer=OpenaiStreamOutputer())
            if yield_output:
                yield message_parser.outputer.output(content="", content_type="Role")
            while not self.wss.closed:
                # `receive_str` either returns text or raises TypeError for a
                # non-text frame (e.g. a server-side close); the `finally`
                # below still releases the connection in that case.
                response_lines = (await self.wss.receive_str()).split("\x1e")
                for line in response_lines:
                    if not line:
                        continue
                    data = json.loads(line)
                    message_type = data.get("type")
                    # Stream: Meaningful Messages
                    if message_type == 1:
                        if yield_output:
                            output = message_parser.parse(data, return_output=True)
                            if isinstance(output, list):
                                for item in output:
                                    yield item
                            elif output:
                                yield output
                        else:
                            message_parser.parse(data)
                    # Stream: End of Conversation
                    elif message_type == 3:
                        finished_str = "\n[Finished]"
                        logger.success(finished_str)
                        self.invocation_id += 1
                        await self._close_connection()
                        if yield_output:
                            yield message_parser.outputer.output(
                                content=finished_str, content_type="Finished"
                            )
                        # The socket is now closed, so the outer loop ends too.
                        break
                    # All other record types are currently ignored:
                    #   2 - list of all messages in the whole conversation
                    #       (not saved yet)
                    #   6 - heartbeat signal
                    #   anything else - not implemented
        finally:
            await self._close_connection()