# llama3-8b-mi-amd / gateway.py
# Author: Aditya Lohia
# Updated Spaces (commit a9409d4)
import json
import requests
def check_server_health(cloud_gateway_api: str) -> bool:
    """
    Use the appropriate API endpoint to check the server health.

    Args:
        cloud_gateway_api: API endpoint to probe.

    Returns:
        True if the server responds 200 on /health, False otherwise
        (including connection failures and timeouts).
    """
    try:
        # Bounded timeout so a hung or unreachable server cannot block
        # the health check indefinitely.
        response = requests.get(cloud_gateway_api + "/health", timeout=5)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        # Superset of ConnectionError: also covers timeouts and other
        # transport-level failures, all of which mean "not healthy".
        print("Failed to establish connection to the server.")
        return False
def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2, ):
    """
    Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
    token-by-token generation from LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
            or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

    Yields:
        str: incremental pieces of generated text as they arrive from the
        server's SSE-style stream.
    """
    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }
    # timeout=(connect, read): fail fast if the gateway is unreachable, while
    # still allowing long pauses between streamed tokens during generation.
    with requests.post(cloud_gateway_api + "/v1/chat/completions",
                       json=payload, stream=True, timeout=(10, 60)) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string and then parse it as json
                chunk_str = chunk.decode('utf-8')

                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]

                # The server signals end-of-stream with a literal [DONE] sentinel.
                if chunk_str.strip() == "[DONE]":
                    break

                # Skip keep-alive / empty payloads rather than feeding them
                # to the JSON parser.
                if not chunk_str.strip():
                    continue

                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the choices
                    content = chunk_json["choices"][0]["delta"].get("content", "")

                    # Yield the generated content as it's streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Handle any potential errors in decoding; malformed
                    # chunks are dropped rather than aborting the stream.
                    continue