import logging
import gradio as gr
from queue import Queue
import time
from prometheus_client import start_http_server, Counter, Histogram, Gauge
import threading
import psutil
import random
from transformers import pipeline
from sklearn.metrics import precision_score, recall_score, f1_score
import requests
from datasets import load_dataset
import os

# --- Ensure chat_log.txt exists ---
log_file = "chat_log.txt"
if not os.path.exists(log_file):
    with open(log_file, 'w') as f:
        pass  # Just create the file

# --- Logging Setup ---
logging.basicConfig(filename=log_file, level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logging.debug("Logging setup complete.")

# Load the NER model
ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
logging.debug("NER pipeline loaded.")

# Load the dataset
dataset = load_dataset("surrey-nlp/PLOD-filtered")
logging.debug("Dataset loaded.")

# --- Prometheus Metrics Setup ---
REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
ERROR_COUNT = Counter('gradio_error_count', 'Total number of errors')
RESPONSE_SIZE = Histogram('gradio_response_size_bytes', 'Size of responses in bytes')
CPU_USAGE = Gauge('system_cpu_usage_percent', 'System CPU usage in percent')
MEM_USAGE = Gauge('system_memory_usage_percent', 'System memory usage in percent')
QUEUE_LENGTH = Gauge('chat_queue_length', 'Length of the chat queue')
logging.debug("Prometheus metrics setup complete.")

# --- Queue and Metrics ---
chat_queue = Queue()  # Define chat_queue globally

# --- Chat Function with Monitoring ---
def chat_function(index):
    logging.debug("Starting chat_function")
    with REQUEST_LATENCY.time():
        REQUEST_COUNT.inc()
        try:
            chat_queue.put(index)
            logging.info(f"Received index from user: {index}")

            # Get the example from the dataset
            example = dataset['train'][int(index)]
            tokens = example['tokens']
            ground_truth_labels = example['ner_tags']

            logging.info(f"Tokens: {tokens}")
            logging.info(f"Ground Truth Labels: {ground_truth_labels}")

            # Predict using the model
            ner_results = ner_pipeline(" ".join(tokens))
            logging.debug(f"NER results: {ner_results}")

            detailed_response = []
            model_predicted_labels = []
            for result in ner_results:
                token = result['word']
                score = result['score']
                entity = result['entity']
                label_id = int(entity.split('_')[-1])  # Extract numeric label from entity (e.g. "LABEL_3" -> 3)
                model_predicted_labels.append(label_id)
                detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}")

            response = "\n".join(detailed_response)
            logging.info(f"Generated response: {response}")

            response_size = len(response.encode('utf-8'))
            RESPONSE_SIZE.observe(response_size)

            time.sleep(random.uniform(0.5, 2.5))  # Simulate processing time

            # Truncate both sequences to the same length so the sklearn metrics can be computed;
            # the model may emit fewer (or more) labels than the reference tokenisation.
            min_len = min(len(model_predicted_labels), len(ground_truth_labels))
            model_predicted_labels = model_predicted_labels[:min_len]
            ground_truth_labels = ground_truth_labels[:min_len]

            precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
            recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
            f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)

            metrics_response = (f"Precision: {precision:.4f}\n"
                                f"Recall: {recall:.4f}\n"
                                f"F1 Score: {f1:.4f}")

            full_response = f"{response}\n\nMetrics:\n{metrics_response}"

            chat_queue.get()
            logging.debug("Finished processing message")
            return full_response
        except Exception as e:
            ERROR_COUNT.inc()
logging.error(f"Error in chat processing: {e}", exc_info=True) return f"An error occurred. Please try again. Error: {e}" # Function to simulate stress test def stress_test(num_requests, index, delay): def send_chat_message(): response = requests.post("http://127.0.0.1:7860/api/predict/", json={ "data": [index], "fn_index": 0 # This might need to be updated based on your Gradio app's function index }) logging.debug(response.json()) threads = [] for _ in range(num_requests): t = threading.Thread(target=send_chat_message) t.start() threads.append(t) time.sleep(delay) # Delay between requests for t in threads: t.join() # --- Gradio Interface with Background Image and Three Windows --- with gr.Blocks(css=""" body { background-image: url("stag.jpeg"); background-size: cover; background-repeat: no-repeat; } """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image with gr.Tab("Chat"): gr.Markdown("## Chat with the Bot") index_input = gr.Textbox(label="Enter dataset index:", lines=1) output = gr.Textbox(label="Response", lines=20) chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output) chat_interface.render() with gr.Tab("Model Parameters"): model_params_display = gr.Textbox(label="Model Parameters", lines=20, interactive=False) # Display model parameters with gr.Tab("Performance Metrics"): request_count_display = gr.Number(label="Request Count", value=0) avg_latency_display = gr.Number(label="Avg. Response Time (s)", value=0) with gr.Tab("Infrastructure"): cpu_usage_display = gr.Number(label="CPU Usage (%)", value=0) mem_usage_display = gr.Number(label="Memory Usage (%)", value=0) with gr.Tab("Logs"): logs_display = gr.Textbox(label="Logs", lines=10) # Increased lines for better visibility with gr.Tab("Stress Testing"): num_requests_input = gr.Number(label="Number of Requests", value=10) index_input_stress = gr.Textbox(label="Dataset Index", value="2") delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1) stress_test_button = gr.Button("Start Stress Test") stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False) def run_stress_test(num_requests, index, delay): stress_test_status.value = "Stress test started..." try: stress_test(num_requests, index, delay) stress_test_status.value = "Stress test completed." 
            except Exception as e:
                stress_test_status.value = f"Stress test failed: {e}"

        stress_test_button.click(run_stress_test,
                                 [num_requests_input, index_input_stress, delay_input],
                                 stress_test_status)

# --- Update Functions ---
def update_metrics(request_count_display, avg_latency_display):
    while True:
        request_count = REQUEST_COUNT._value.get()
        # Average latency = total observed seconds / number of observations,
        # taken from the histogram's _sum and _count samples
        latency_samples = REQUEST_LATENCY.collect()[0].samples
        latency_sum = next((s.value for s in latency_samples if s.name.endswith('_sum')), 0.0)
        latency_count = next((s.value for s in latency_samples if s.name.endswith('_count')), 0.0)
        avg_latency = latency_sum / latency_count if latency_count else 0.0  # Avoid division by zero

        request_count_display.value = request_count
        avg_latency_display.value = round(avg_latency, 2)

        time.sleep(5)  # Update every 5 seconds

def update_usage(cpu_usage_display, mem_usage_display):
    while True:
        cpu_usage_display.value = psutil.cpu_percent()
        mem_usage_display.value = psutil.virtual_memory().percent
        CPU_USAGE.set(psutil.cpu_percent())
        MEM_USAGE.set(psutil.virtual_memory().percent)
        time.sleep(5)

def update_logs(logs_display):
    while True:
        with open(log_file, "r") as log_file_handler:
            logs = log_file_handler.readlines()
        logs_display.value = "".join(logs[-10:])  # Display last 10 lines
        time.sleep(1)  # Update every second

def display_model_params(model_params_display):
    while True:
        model_params = ner_pipeline.model.config.to_dict()
        model_params_str = "\n".join(f"{key}: {value}" for key, value in model_params.items())
        model_params_display.value = model_params_str
        time.sleep(10)  # Update every 10 seconds

def update_queue_length():
    while True:
        QUEUE_LENGTH.set(chat_queue.qsize())
        time.sleep(1)  # Update every second

# --- Start Threads ---
threading.Thread(target=start_http_server, args=(8000,), daemon=True).start()
threading.Thread(target=update_metrics, args=(request_count_display, avg_latency_display), daemon=True).start()
threading.Thread(target=update_usage, args=(cpu_usage_display, mem_usage_display), daemon=True).start()
threading.Thread(target=update_logs, args=(logs_display,), daemon=True).start()
threading.Thread(target=display_model_params, args=(model_params_display,), daemon=True).start()
threading.Thread(target=update_queue_length, daemon=True).start()

# Launch the app
demo.launch(share=True)
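
# A quick way to verify the exporter (a minimal sketch, not part of the app itself):
# with the app running, the Prometheus endpoint started above on port 8000 can be
# scraped from a separate Python shell, e.g.
#
#   import requests
#   print([line for line in requests.get("http://127.0.0.1:8000/metrics").text.splitlines()
#          if line.startswith("gradio_request_count")])
#
# The Gradio UI itself is served on the default port 7860, which is also the port
# that stress_test() targets.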