import json
import requests
def check_server_health(cloud_gateway_api: str, timeout: float = 5.0) -> bool:
    """
    Probe the gateway's /health endpoint to check the server health.

    Args:
        cloud_gateway_api: Base URL of the API endpoint to probe.
        timeout: Seconds to wait for the probe before giving up; prevents
            the caller from blocking indefinitely on an unresponsive host.

    Returns:
        True if the server responded with HTTP 200, False otherwise
        (non-200 status, connection failure, or timeout).
    """
    try:
        # An explicit timeout is required: without it requests.get can hang forever.
        response = requests.get(cloud_gateway_api + "/health", timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        # RequestException covers ConnectionError, Timeout, and other
        # transport-level failures — the original caught only ConnectionError,
        # letting timeouts propagate uncaught.
        print("Failed to establish connection to the server.")
        return False
def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2,
                       timeout: float = 60.0, ):
    """
    Request streaming generation from the cloud gateway API. Uses the simple
    requests module with stream=True to utilize token-by-token generation
    from the LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
            or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
        timeout: seconds to wait for the connection / between streamed chunks
            before aborting, so a stalled server cannot hang the generator.

    Yields:
        str: incremental pieces of generated text as they arrive from the server.

    Raises:
        requests.HTTPError: if the server responds with a non-2xx status.
    """
    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }
    with requests.post(cloud_gateway_api + "/v1/chat/completions", json=payload,
                       stream=True, timeout=timeout) as response:
        # Fail fast on HTTP errors instead of trying to parse an error body
        # as a token stream.
        response.raise_for_status()
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string and then parse it as json
                chunk_str = chunk.decode('utf-8')
                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]
                # The server signals end-of-stream with a literal [DONE] marker
                if chunk_str.strip() == "[DONE]":
                    break
                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the choices
                    content = chunk_json["choices"][0]["delta"].get("content", "")
                    # Yield the generated content as it's streamed
                    if content:
                        yield content
                except (json.JSONDecodeError, KeyError, IndexError):
                    # Skip malformed chunks (bad JSON or unexpected structure)
                    # rather than killing the generator mid-stream.
                    continue