""" | |
License: | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
In no event shall the authors or copyright holders be liable | |
for any claim, damages or other liability, whether in an action of contract,otherwise, | |
arising from, out of or in connection with the software or the use or | |
other dealings in the software. | |
Copyright (c) 2024 pi19404. All rights reserved. | |
Authors: | |
pi19404 <pi19404@gmail.com> | |
""" | |
""" | |
Gradio Interface for Shield Gemma LLM Evaluator | |
This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator. | |
It allows users to input JSON data and select various options to evaluate the content | |
for policy violations. | |
Functions: | |
my_inference_function: The main inference function to process input data and return results. | |
""" | |
import json
import os
import threading
import time
from collections import OrderedDict

import gradio as gr
import httpx
from gradio_client import Client

API_TOKEN = os.getenv("API_TOKEN")

lock = threading.Lock()

# client = Client("pi19404/ai-worker", hf_token=API_TOKEN)

# An OrderedDict used as an LRU cache of Gradio clients, limited to 15 entries.
client_cache = OrderedDict()
MAX_CACHE_SIZE = 15
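
# A minimal sketch (plain Python, no Gradio) of the LRU behavior this cache
# relies on: move_to_end marks an entry as most recently used, and
# popitem(last=False) evicts the least recently used one.
#
#     cache = OrderedDict()
#     cache["a"] = 1
#     cache["b"] = 2
#     cache.move_to_end("a")       # order is now b, a
#     cache.popitem(last=False)    # evicts ("b", 2)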
def my_inference_function(client, input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    The main inference function to process input data and return results.

    Args:
        client (tuple): The (Client, headers) pair stored in the session State;
            the Gradio client is at index 0.
        input_data (str or dict): The input data in JSON format.
        output_data (str): The model response text to evaluate, if any.
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used.

    Returns:
        str: The output data in JSON format.
    """
    with lock:
        try:
            result = client[0].predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function"
            )
            print("inference result:", result)
            return result  # JSON string returned by the remote Space
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
def wake_up_space_with_retries(space_url, token, retries=5, wait_time=10):
    """
    Attempt to wake up the Hugging Face Space, with retries.

    Retries a number of times to allow for the delay while the Space wakes up.

    :param space_url: The URL (or repo id) of the Hugging Face Space.
    :param token: The Hugging Face API token.
    :param retries: Number of retries if the Space is sleeping.
    :param wait_time: Time to wait between retries (in seconds).
    :return: A connected Client instance, or None if the Space never woke up.
    """
    for attempt in range(retries):
        try:
            print(f"Attempt {attempt + 1} to wake up the Space...")
            # Initialize the Gradio client with a 30-second timeout.
            client = Client(space_url, hf_token=token, timeout=httpx.Timeout(30.0))
            # Make a lightweight call to wake the Space; the function expects
            # the (client, headers) tuple shape used by the session State.
            my_inference_function((client, None), "test input", "", "scoring", 10, 10, "2B")
            print("Space is awake and ready!")
            return client
        except httpx.ReadTimeout:
            print(f"Request timed out on attempt {attempt + 1}. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
        except Exception as e:
            print(f"An error occurred on attempt {attempt + 1}: {e}")
        # Wait before retrying.
        if attempt < retries - 1:
            print(f"Waiting for {wait_time} seconds before retrying...")
            time.sleep(wait_time)
    print("Space is still not active after multiple attempts.")
    return None
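
# A possible variant (sketch, not used here): the same retry loop with
# exponential backoff instead of a fixed wait, doubling wait_time per attempt
# and capping it at two minutes:
#
#     for attempt in range(retries):
#         try:
#             return Client(space_url, hf_token=token, timeout=httpx.Timeout(30.0))
#         except Exception:
#             time.sleep(min(wait_time * 2 ** attempt, 120))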

# default_client = Client("pi19404/ai-worker", hf_token=API_TOKEN)
default_client = wake_up_space_with_retries("pi19404/ai-worker", API_TOKEN)
def get_client_for_ip(ip_address, x_ip_token):
    """
    Retrieve or create a client for the given IP address.

    This function implements a caching mechanism to store up to MAX_CACHE_SIZE clients.
    If a client for the given token exists in the cache, it is returned and moved to the
    end of the cache (marking it as most recently used). If not, a new client is created,
    added to the cache, and the least recently used client is removed if the cache is full.

    Args:
        ip_address (str): The IP address of the client.
        x_ip_token (str): The X-IP-Token header value for the client.

    Returns:
        Client: A Gradio client instance for the given IP address.
    """
    # Fall back to the IP address when no X-IP-Token header is present.
    if x_ip_token is None:
        x_ip_token = ip_address

    # With neither a token nor an IP address, reuse the shared default client.
    if x_ip_token is None:
        return default_client

    if x_ip_token in client_cache:
        # Move the accessed item to the end (most recently used).
        client_cache.move_to_end(x_ip_token)
        return client_cache[x_ip_token]

    # Create a new client bound to this token.
    new_client = Client("pi19404/ai-worker", hf_token=API_TOKEN, headers={"X-IP-Token": x_ip_token})

    # Add to the cache, evicting the least recently used entry if necessary.
    if len(client_cache) >= MAX_CACHE_SIZE:
        client_cache.popitem(last=False)
    client_cache[x_ip_token] = new_client

    return new_client
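
# Example (hypothetical values): repeated calls with the same token reuse one
# cached client, so a session keeps hitting the same backend connection.
#
#     c1 = get_client_for_ip("203.0.113.7", "token-abc")
#     c2 = get_client_for_ip("203.0.113.7", "token-abc")
#     assert c1 is c2   # cache hit; no new Client was created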
def set_client_for_session(request: gr.Request):
    """
    Set up a client for the current session and collect request headers.

    This function is called when a new session is initiated. It retrieves or creates
    a client for the session's IP address and collects all request headers for debugging.

    Args:
        request (gr.Request): The Gradio request object for the current session.

    Returns:
        tuple: A tuple containing:
            - Client: The Gradio client instance for the session.
            - str: A JSON string of all request headers.
    """
    # Collect all headers in a dictionary.
    all_headers = dict(request.headers.items())

    # Print headers to the console for debugging.
    print("All request headers:")
    print(json.dumps(all_headers, indent=2))

    x_ip_token = request.headers.get('x-ip-token', None)
    ip_address = request.client.host
    print("ip address is", ip_address)

    client = get_client_for_ip(ip_address, x_ip_token)

    # Return both the client and the headers.
    return client, json.dumps(all_headers, indent=2)
# The "gradio/text-to-image" space is a ZeroGPU space | |
with gr.Blocks() as demo: | |
""" | |
Main Gradio interface setup. | |
This block sets up the Gradio interface, including: | |
- A State component to store the client for the session. | |
- A JSON component to display request headers for debugging. | |
- Other UI components (not shown in this snippet). | |
- A load event that calls set_client_for_session when the interface is loaded. | |
""" | |
    gr.Markdown("## LLM Safety Evaluation")
    client = gr.State()

    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[client, input_text, output_text, mode_input,
                    max_length_input, max_new_tokens_input, model_size_input],
            outputs=response_text
        )
    # with gr.Tab("API Input"):
    #     api_input = gr.JSON(label="Input JSON")
    #     mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
    #     max_length_input_api = gr.Number(label="Max Length", value=150)
    #     max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
    #     model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
    #     api_output = gr.JSON(label="Output JSON")
    #     api_button = gr.Button("Submit")
    #     api_button.click(fn=my_inference_function, inputs=[api_input, api_output, mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)
    # The State receives the (client, headers) tuple returned by
    # set_client_for_session; my_inference_function indexes it with [0].
    demo.load(set_client_for_session, None, client)

demo.launch(share=True)