import json
import requests


def check_server_health(cloud_gateway_api: str):
    """
    Use the appropriate API endpoint to check the server health.
    Args:
        cloud_gateway_api: API endpoint to probe.
    Returns:
        True if the server is active, False otherwise.
    """
    try:
        response = requests.get(cloud_gateway_api + "/health")
        if response.status_code == 200:
            return True
    except requests.ConnectionError:
        print("Failed to establish connection to the server.")
    return False
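
Right after deployment, the server can take a while to come up, so a single probe may fail even though everything is fine. The short sketch below polls the health endpoint before any generation request is sent; the wait_for_server helper, retry count, and delay are illustrative additions, not part of the original code:

import time

def wait_for_server(cloud_gateway_api: str, retries: int = 10, delay: float = 5.0) -> bool:
    # Poll the health endpoint until the server responds or retries run out.
    for _ in range(retries):
        if check_server_health(cloud_gateway_api):
            return True
        time.sleep(delay)
    return False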
def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2):
    """
    Request streaming generation from the cloud gateway API. Uses the requests module with
    stream=True to consume token-by-token generation from the LLM.
    Args:
        message: prompt from the user.
        system_prompt: system prompt to prepend.
        cloud_gateway_api: API endpoint to send the request to.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to modulate the next-token probabilities.
        top_p: if set to a float < 1, only the smallest set of most probable tokens whose probabilities
            add up to top_p or higher are kept for generation.
        top_k: the number of highest-probability vocabulary tokens to keep for top-k filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
    Yields:
        Generated text fragments as they arrive from the stream.
    """
    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }
    with requests.post(cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string
                chunk_str = chunk.decode('utf-8')
                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]
                # Stop when the stream signals completion
                if chunk_str.strip() == "[DONE]":
                    break
                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the streamed delta
                    content = chunk_json["choices"][0]["delta"].get("content", "")
                    # Yield the generated content as it is streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip chunks that cannot be decoded as JSON
                    continue
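
Wired together, the two helpers form a minimal streaming client. A usage sketch follows, assuming a hypothetical gateway URL (substitute your own endpoint):

GATEWAY = "http://localhost:9000"  # hypothetical endpoint; replace with your gateway URL

if check_server_health(GATEWAY):
    for token in request_generation(
            message="Explain what a health-check endpoint is.",
            system_prompt="You are a helpful assistant.",
            cloud_gateway_api=GATEWAY):
        print(token, end="", flush=True)  # render tokens as they stream in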