File size: 3,005 Bytes
d76943e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
import sys

import requests

def _stream_chat_response(url, headers, payload):
    """
    Streams the chat response from the given URL with the specified headers and payload.
    
    Args:
        url (str): The URL to send the POST request to.
        headers (dict): The headers for the POST request.
        payload (dict): The payload for the POST request.
    
    Raises:
        InvalidArgument: If the payload does not have the 'stream' key.
        ConnectionError: If the request fails.
    
    Yields:
        str: The content of the streamed response.
    """
    if not payload.get('stream'):
        raise ValueError('This method can only handle stream payload')
    
    try:
        # Make the POST request
        response = requests.post(url, headers=headers, json=payload, stream=True)
        response.raise_for_status()  # Raise an error for bad status codes

        # Process the streamed response
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                DATA_PREFIX = "data: "
                if decoded_line.startswith(DATA_PREFIX):
                    decoded_line = decoded_line[len(DATA_PREFIX):]  # Remove the "data: " prefix
                    if decoded_line.strip() == "[DONE]":
                        break
                    try:
                        json_data = json.loads(decoded_line)
                        content = json_data.get('choices', [{}])[0].get('delta', {}).get('content', '')
                        if content:
                            yield content
                    except json.JSONDecodeError as e:
                        print(f"Warning: Error decoding JSON: {decoded_line}. Skipping this line.")
    except requests.RequestException as e:
        raise ConnectionError(f"Request failed: {e}") from e

def Streamer(history, **kwargs):
    """
    Stream a chat completion for ``history`` from the endpoint configured in the environment.

    The endpoint and credential come from the ``URL`` and ``TOKEN`` environment
    variables. Because this is a generator, they are read (and validated) only
    when iteration starts.

    Args:
        history (list[dict]): The chat messages, sent as the request's
            "messages" field, e.g. ``[{"role": "user", "content": "..."}]``.
        **kwargs: Extra request fields. They override the defaults
            (``max_tokens``, ``stop``, ``model``, ``stream``).

    Raises:
        EnvironmentError: If ``URL`` or ``TOKEN`` is unset or empty.
        ValueError: If ``kwargs`` sets a falsy ``stream`` (raised by the
            streaming helper).
        ConnectionError: If the HTTP request or the streamed read fails.

    Yields:
        str: Each piece of response content as it arrives.
    """
    url = os.getenv('URL')
    token = os.getenv('TOKEN')

    if not url or not token:
        raise EnvironmentError("URL or TOKEN environment variable is not set.")

    headers = {
        # The token is sent verbatim under the Basic scheme. NOTE(review):
        # presumably it is already the encoded credential the API expects — confirm.
        "Authorization": f"Basic {token}",
        "Content-Type": "application/json",
    }
    # Later keys win, so caller-supplied kwargs override these defaults.
    payload = {
        "messages": history,
        "max_tokens": 1000,
        "stop": ["<|eot_id|>"],
        "model": "llama3-405b",
        "stream": True,
        **kwargs,
    }

    # ``yield from`` forwards close()/throw() to the inner generator, so an
    # early stop by the caller promptly releases the underlying HTTP response.
    yield from _stream_chat_response(url, headers, payload)

# Example usage: stream a reply to a single prompt and echo it as it arrives.
if __name__ == "__main__":
    history = [{"role": "user", "content": "Tell me a joke"}]
    try:
        for content in Streamer(history):
            # Chunks carry no newline, so flush each one; otherwise a
            # line-buffered stdout would hold the text until the stream ends.
            print(content, end='', flush=True)
        print()  # terminate the streamed line
    except Exception as e:
        # Top-level boundary: report on stderr (away from the streamed text)
        # and signal failure through the exit status.
        print(f"An error occurred: {e}", file=sys.stderr)
        sys.exit(1)