|
import asyncio
import queue
import time
import uuid
from threading import Lock
from typing import Optional

from fastapi import FastAPI, HTTPException
from jupyter_client import KernelManager
from pydantic import BaseModel
|
|
|
|
|
app = FastAPI()

# In-memory registry of live sessions, keyed by the session_id string.
# create_session stores each kernel's manager under 'km' and its client under 'kc'.
kernel_sessions: dict = dict()

# Serializes every read and write of ``kernel_sessions`` done by the handlers below.
sessions_lock = Lock()
|
|
|
class CodeExecutionRequest(BaseModel):
    """Request body for ``POST /execute/{session_id}``."""

    # Source text to send to the session's kernel for execution.
    code: str

    # Optional boolean, defaults to False. Presumably asks for a kernel
    # restart before ``code`` runs -- confirm against the /execute handler.
    restart: Optional[bool] = False
|
|
|
class CreateSessionResponse(BaseModel):
    """Response body for ``POST /create_session``."""

    # Identifier of the newly created session (a UUID4 rendered as a string).
    # Clients pass it back in the /execute/{session_id} and
    # /shutdown/{session_id} paths.
    session_id: str
|
|
|
# Upper bound, in seconds, on how long a new kernel may take to answer a
# kernel_info request after being started.
KERNEL_STARTUP_TIMEOUT_SECONDS = 60.0


@app.post("/create_session", response_model=CreateSessionResponse)
async def create_session():
    """
    Creates a new Jupyter kernel session and returns the session_id.

    Starting a kernel is blocking (it spawns a process and connects sockets),
    so that work runs in the default thread-pool executor. The event loop stays
    responsive, and ``sessions_lock`` is held only for the registry insert.

    The handler waits until the kernel answers on its channels before
    returning. Without this wait, the first /execute call could race the
    kernel's startup and lose early IOPub output.

    Returns:
        CreateSessionResponse with the new session's id.

    Raises:
        HTTPException(500): the kernel failed to start or did not become ready
            within ``KERNEL_STARTUP_TIMEOUT_SECONDS``.
    """
    session_id = str(uuid.uuid4())
    loop = asyncio.get_running_loop()

    def start_kernel():
        km = KernelManager(kernel_name='python3')
        km.start_kernel()
        kc = None
        try:
            kc = km.client()
            kc.start_channels()
            kc.wait_for_ready(timeout=KERNEL_STARTUP_TIMEOUT_SECONDS)
        except Exception:
            # Do not leave an orphaned kernel process behind on failure.
            if kc is not None:
                kc.stop_channels()
            km.shutdown_kernel(now=True)
            raise
        return km, kc

    try:
        km, kc = await loop.run_in_executor(None, start_kernel)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to start kernel: {e}")

    with sessions_lock:
        kernel_sessions[session_id] = {'km': km, 'kc': kc}

    return CreateSessionResponse(session_id=session_id)
|
|
|
# Upper bound, in seconds, on how long one /execute request waits for the
# kernel to finish running the submitted code.
EXECUTION_TIMEOUT_SECONDS = 60.0

# Longest single blocking read on the IOPub channel. It is kept short so the
# overall deadline above is re-checked regularly.
IOPUB_POLL_SECONDS = 2.0


@app.post("/execute/{session_id}")
async def execute_code(session_id: str, request: CodeExecutionRequest):
    """
    Executes code in the specified session's Jupyter kernel.

    The blocking kernel I/O runs in the default thread-pool executor.

    Output collection:
        Output is read from the kernel's IOPub channel. Only messages whose
        parent header is this request's ``execute`` call are kept, so output
        left over from earlier or concurrent executions is ignored. Collection
        stops when the kernel reports ``idle`` for this request, or when
        ``EXECUTION_TIMEOUT_SECONDS`` elapses. Quiet periods while code is
        still running do not end collection.

    Serialization:
        Executions on the same session are serialized with a per-session lock.
        The kernel client is shared, and reading IOPub from two threads at once
        would split one request's messages between both responses.

    Restart:
        If ``request.restart`` is true, the kernel is restarted (discarding its
        state) and awaited before the code is executed.

    Returns:
        ``{"status": "success", "output": <str>}``. ``output`` is the collected
        stream text, traceback lines and ``text/plain`` results joined with
        newlines. A timeout is reported as a line inside ``output``.

    Raises:
        HTTPException(404): unknown ``session_id``.
        HTTPException(500): any error while talking to the kernel.
    """
    with sessions_lock:
        session = kernel_sessions.get(session_id)
        # Created lazily so every registered session gets one, however it was created.
        exec_lock = session.setdefault('exec_lock', Lock()) if session else None

    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    km = session['km']
    kc = session['kc']
    loop = asyncio.get_running_loop()

    def run_code():
        with exec_lock:
            if request.restart:
                km.restart_kernel()
                kc.wait_for_ready(timeout=EXECUTION_TIMEOUT_SECONDS)

            # execute() returns the msg_id that the kernel copies into the
            # parent_header of every IOPub message this request produces.
            msg_id = kc.execute(request.code)
            deadline = time.monotonic() + EXECUTION_TIMEOUT_SECONDS
            output = []

            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    output.append(
                        "Error capturing output: execution did not finish "
                        f"within {EXECUTION_TIMEOUT_SECONDS} seconds"
                    )
                    break

                try:
                    msg = kc.get_iopub_msg(timeout=min(remaining, IOPUB_POLL_SECONDS))
                except queue.Empty:
                    # No output yet (e.g. a long computation); keep waiting.
                    continue

                if msg.get('parent_header', {}).get('msg_id') != msg_id:
                    continue

                msg_type = msg['msg_type']
                content = msg['content']
                if msg_type == 'stream':
                    output.append(content['text'])
                elif msg_type == 'error':
                    output.extend(content['traceback'])
                elif msg_type in ('execute_result', 'display_data'):
                    output.append(content['data'].get('text/plain', ''))
                elif msg_type == 'status' and content['execution_state'] == 'idle':
                    break

            return "\n".join(output)

    try:
        output = await loop.run_in_executor(None, run_code)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"status": "success", "output": output}
|
|
|
@app.post("/shutdown/{session_id}")
async def shutdown_session(session_id: str):
    """
    Shuts down the Jupyter kernel associated with the specified session_id.

    The session is removed from the registry first, so later requests for the
    same id get a 404 even while the shutdown is still in progress.

    The blocking teardown runs in the default thread-pool executor so it does
    not stall the event loop. The kernel shutdown is attempted even if stopping
    the client channels fails, so the kernel process is not leaked.

    Returns:
        ``{"status": "success", "message": "Session terminated"}``.

    Raises:
        HTTPException(404): unknown ``session_id``.
        HTTPException(500): an error occurred while stopping the session.
    """
    with sessions_lock:
        session = kernel_sessions.pop(session_id, None)

    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    def teardown():
        try:
            session['kc'].stop_channels()
        finally:
            session['km'].shutdown_kernel()

    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, teardown)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"status": "success", "message": "Session terminated"}
|
|
|
@app.get("/list_sessions")
async def list_sessions():
    """
    Lists all active Jupyter kernel sessions.

    The ids are copied out under ``sessions_lock`` and the response is built
    afterwards.

    Returns:
        ``{"status": "success", "sessions": [{"session_id": <id>}, ...]}``.
    """
    with sessions_lock:
        active_ids = list(kernel_sessions)

    return {
        "status": "success",
        "sessions": [{"session_id": active_id} for active_id in active_ids],
    }
|
|