from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import asyncio
import httpx
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OLLAMA_URL = "http://localhost:11434"
MODEL_NAME = "llama3.2:1b-papalia"

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
templates = Jinja2Templates(directory="templates")


class QueryRequest(BaseModel):
    prompt: str
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 500


class QueryResponse(BaseModel):
    response: str
    model: str = MODEL_NAME


async def generate_with_retries(client: httpx.AsyncClient, data: dict, max_retries: int = 3) -> dict:
    """Call Ollama's /api/generate, retrying transient failures with a short backoff."""
    for attempt in range(max_retries):
        try:
            response = await client.post(
                f"{OLLAMA_URL}/api/generate",
                json=data,
                timeout=90.0,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Re-raise on the final attempt; otherwise back off briefly and
            # retry with the same client instead of opening and closing a new one.
            if attempt == max_retries - 1:
                raise
            logger.warning(f"Attempt {attempt + 1} failed ({e}); retrying")
            await asyncio.sleep(2 ** attempt)


@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/generate", response_model=QueryResponse)
async def generate_response(query: QueryRequest):
    try:
        async with httpx.AsyncClient(timeout=90.0) as client:
            result = await generate_with_retries(
                client,
                {
                    "model": MODEL_NAME,
                    "prompt": query.prompt,
                    "stream": False,
                    # Ollama reads sampling parameters from "options";
                    # "num_predict" is its name for the max-token limit.
                    "options": {
                        "temperature": query.temperature,
                        "num_predict": query.max_tokens,
                    },
                },
            )
        return QueryResponse(response=result.get("response", ""))
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Report whether the local Ollama server is reachable."""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(OLLAMA_URL)
            response.raise_for_status()
        return {"status": "healthy"}
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
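

# ---------------------------------------------------------------------------
# Minimal run/usage sketch (not part of the original app). It assumes this
# module is saved as main.py and served with uvicorn on port 8000; adjust the
# server, host, and port to match how the project actually deploys it.
# With Ollama running locally, a request such as
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello"}'
# should return {"response": "...", "model": "llama3.2:1b-papalia"}.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)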