import json
import requests


def check_server_health(cloud_gateway_api: str):
    """
    Use the appropriate API endpoint to check the server health.
    Args:
        cloud_gateway_api: API endpoint to probe.

    Returns:
        True if server is active, false otherwise.
    """
    try:
        response = requests.get(cloud_gateway_api + "/health")
        if response.status_code == 200:
            return True
    except requests.ConnectionError:
        print("Failed to establish connection to the server.")

    return False


def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2):
    """
    Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
    token-by-token generation from LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
                or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

    Returns:

    """

    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }

    with requests.post(cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string and then parse it as JSON
                chunk_str = chunk.decode('utf-8')

                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]

                # Stop when the server signals the end of the stream
                if chunk_str.strip() == "[DONE]":
                    break

                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the choices
                    content = chunk_json["choices"][0]["delta"].get("content", "")

                    # Yield the generated content as it's streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip chunks that cannot be parsed as JSON
                    continue
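

# A minimal usage sketch, not part of the original module's documented interface:
# it shows how the two functions above are meant to work together. The gateway URL
# below is an assumed placeholder and must be replaced with the address of a running
# OpenAI-compatible server; the prompts are purely illustrative.
if __name__ == "__main__":
    # Hypothetical base URL; substitute your own deployment's endpoint.
    GATEWAY_URL = "http://localhost:8000"

    if check_server_health(GATEWAY_URL):
        # Stream the response and print each fragment as it arrives.
        for fragment in request_generation(
                message="Explain what a health-check endpoint is in one sentence.",
                system_prompt="You are a helpful assistant.",
                cloud_gateway_api=GATEWAY_URL,
        ):
            print(fragment, end="", flush=True)
        print()
    else:
        print("Server is not reachable; skipping generation request.")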