executor / app /jupyter_api.py
eggie5-adyen's picture
typo
04eb088
raw
history blame
4.21 kB
import asyncio
import queue
import uuid
from threading import Lock
from typing import Optional

from fastapi import FastAPI, HTTPException
from jupyter_client import KernelManager
from pydantic import BaseModel
# FastAPI application object; the route decorators in this module register on it.
app = FastAPI()
# In-memory registry of live sessions, keyed by session_id. Each value is a
# dict of the form {'km': <KernelManager>, 'kc': <client returned by km.client()>}.
kernel_sessions = {}
# Lock taken by the endpoints whenever they read or modify kernel_sessions.
sessions_lock = Lock()
# Request body accepted by the POST /execute/{session_id} endpoint.
class CodeExecutionRequest(BaseModel):
    # Source code to run in the session's kernel.
    code: str
    # Kept for backward compatibility; the execute handler in this file only
    # reads `code` and does not act on this flag.
    restart: Optional[bool] = False
# Response body of POST /create_session (declared as its response_model).
class CreateSessionResponse(BaseModel):
    # Session identifier (a uuid4 string) that the /execute/{session_id} and
    # /shutdown/{session_id} routes take in their path.
    session_id: str
def _start_kernel_session():
    """Start a python3 kernel and a connected client (blocking helper).

    Returns:
        dict: ``{'km': KernelManager, 'kc': client}``, the shape stored in
        ``kernel_sessions``.

    Raises:
        Exception: Whatever ``jupyter_client`` raises during startup. If the
            kernel process was already started, it is shut down before the
            error propagates so that no orphan process is left behind.
    """
    km = KernelManager(kernel_name='python3')
    km.start_kernel()
    try:
        kc = km.client()
        kc.start_channels()
    except Exception:
        # The kernel process is already running at this point; kill it rather
        # than leaking it when the client side cannot be set up.
        km.shutdown_kernel(now=True)
        raise
    return {'km': km, 'kc': kc}


@app.post("/create_session", response_model=CreateSessionResponse)
async def create_session():
    """
    Creates a new Jupyter kernel session and returns the session_id.

    The kernel is started in a worker thread so the event loop stays
    responsive, and the registry lock is held only for the dictionary insert.

    Returns:
        CreateSessionResponse: carries the uuid4 string identifying the session.

    Raises:
        HTTPException: 500 if the kernel or its client could not be started.
    """
    session_id = str(uuid.uuid4())  # Generate a unique session ID

    loop = asyncio.get_running_loop()
    try:
        # Kernel startup blocks (process spawn, channel setup); keep it off
        # the event loop, as execute_code does for code execution.
        session = await loop.run_in_executor(None, _start_kernel_session)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    with sessions_lock:
        kernel_sessions[session_id] = session
    return CreateSessionResponse(session_id=session_id)
def _collect_execution_output(kc, km, code, poll_seconds=2):
    """Run *code* on the kernel and gather its textual output (blocking helper).

    Args:
        kc: Kernel client whose channels are already started.
        km: Kernel manager owning the kernel; used to detect a dead kernel.
        code: Source code to execute.
        poll_seconds: How long each iopub read blocks before the kernel's
            liveness is re-checked. This is a polling interval, not an
            execution time limit.

    Returns:
        str: Stream text, result/display ``text/plain`` values and error
        traceback lines, joined with newlines.
    """
    # execute() returns the id of the request; every message the kernel emits
    # for this execution carries it in its parent header.
    msg_id = kc.execute(code)
    output = []
    while True:
        try:
            msg = kc.get_iopub_msg(timeout=poll_seconds)

            # Skip messages belonging to other requests (e.g. leftovers from
            # an earlier execution) so they are not attributed to this one.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue

            msg_type = msg['msg_type']
            content = msg['content']
            if msg_type == 'stream':
                output.append(content['text'])
            elif msg_type == 'error':
                # Include traceback if there's an error
                output.extend(content['traceback'])
            elif msg_type in ('execute_result', 'display_data'):
                # Capture the plain-text representation if it exists
                output.append(content['data'].get('text/plain', ''))
            elif msg_type == 'status' and content['execution_state'] == 'idle':
                # The kernel reports idle once this request has finished.
                break
        except queue.Empty:
            # A quiet channel only means the code is still running without
            # producing output; keep waiting unless the kernel has died.
            if km.is_alive():
                continue
            output.append("Error capturing output: kernel died before execution finished")
            break
        except Exception as e:
            output.append(f"Error capturing output: {str(e)}")
            break
    return "\n".join(output)


@app.post("/execute/{session_id}")
async def execute_code(session_id: str, request: CodeExecutionRequest):
    """
    Executes code in the specified session's Jupyter kernel.

    Only ``request.code`` is used; ``request.restart`` is accepted but not
    acted on.

    Returns:
        dict: ``{"status": "success", "output": <collected text>}``.

    Raises:
        HTTPException: 404 if the session does not exist, 500 if running the
            code raised unexpectedly.
    """
    with sessions_lock:
        session = kernel_sessions.get(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    loop = asyncio.get_running_loop()
    try:
        # Talking to the kernel blocks, so do it in a worker thread.
        output = await loop.run_in_executor(
            None, _collect_execution_output, session['kc'], session['km'], request.code
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "output": output}
def _terminate_session(session):
    """Stop the client channels and shut the kernel down (blocking helper).

    The kernel shutdown is attempted even when stopping the channels fails,
    because the session has already been removed from the registry and
    nothing else could clean the process up afterwards.
    """
    try:
        session['kc'].stop_channels()
    finally:
        session['km'].shutdown_kernel()


@app.post("/shutdown/{session_id}")
async def shutdown_session(session_id: str):
    """
    Shuts down the Jupyter kernel associated with the specified session_id.

    Returns:
        dict: ``{"status": "success", "message": "Session terminated"}``.

    Raises:
        HTTPException: 404 if the session does not exist, 500 if stopping the
            channels or the kernel raised.
    """
    with sessions_lock:
        session = kernel_sessions.pop(session_id, None)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    loop = asyncio.get_running_loop()
    try:
        # Kernel shutdown can block while the process exits; run it in a
        # worker thread so other requests keep being served.
        await loop.run_in_executor(None, _terminate_session, session)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": "Session terminated"}
@app.get("/list_sessions")
async def list_sessions():
    """Lists all active Jupyter kernel sessions."""
    # Snapshot the ids while holding the lock; the response is assembled from
    # that snapshot afterwards.
    with sessions_lock:
        active_ids = list(kernel_sessions)

    payload = {
        "status": "success",
        "sessions": [{"session_id": sid} for sid in active_ids],
    }
    return payload