# Spaces: Sleeping
import asyncio
import logging
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HTML templates are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")

# Application instance plus a fully permissive cross-origin policy.
# NOTE(review): wildcard origins together with allow_credentials=True likely
# lets any site make credentialed cross-origin requests; tighten this if
# cookies or auth are ever added.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class QueryRequest(BaseModel):
    """Body of a generation request sent by the client.

    Attributes:
        prompt: Text to send to the model.
        temperature: Requested temperature; defaults to 0.7, None is accepted.
        max_tokens: Requested token limit; defaults to 500, None is accepted.
    """

    prompt: str
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 500
class QueryResponse(BaseModel):
    """Shape of a generation result: the generated text and the model name.

    NOTE(review): the handlers visible in this file return a plain dict with
    the same keys rather than an instance of this model.
    """

    response: str
    model: str = "llama3.2:1b-papalia"
async def generate_with_retries(
    client: httpx.AsyncClient,
    data: dict,
    max_retries: int = 3,
    *,
    url: str = "http://localhost:11434/api/generate",
    timeout: float = 90.0,
    backoff: float = 0.5,
) -> dict:
    """POST ``data`` to the generation endpoint, retrying transient failures.

    Only failures that another attempt could plausibly fix are retried:
    transport-level errors (connection refused, timeouts, ...) and 5xx
    responses. Any other HTTP error status is re-raised immediately. Between
    attempts the coroutine sleeps ``backoff * 2**attempt`` seconds.

    Args:
        client: Open client used for every attempt. The caller owns it; it is
            not closed here.
        data: JSON payload sent as the request body.
        max_retries: Total number of attempts; must be at least 1.
        url: Endpoint to POST to.
        timeout: Per-request timeout in seconds.
        backoff: Base delay in seconds between attempts (0 disables sleeping).

    Returns:
        The decoded JSON body of the first successful response.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        httpx.HTTPStatusError: Immediately for a non-5xx error status, or for
            a 5xx response on the last attempt.
        httpx.TransportError: If the last attempt fails at the transport level.
    """
    if max_retries < 1:
        # Without this guard the loop body never runs and the function
        # silently returns None despite the ``-> dict`` annotation.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    last_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            response = await client.post(url, json=data, timeout=timeout)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as exc:
            # Client errors (bad payload, unknown model, ...) won't change on retry.
            if exc.response.status_code < 500 or attempt == last_attempt:
                raise
            error: Exception = exc
        except httpx.TransportError as exc:
            if attempt == last_attempt:
                raise
            error = exc

        delay = backoff * (2 ** attempt)
        logger.warning(
            "Generation attempt %d/%d failed (%s); retrying in %.2fs",
            attempt + 1,
            max_retries,
            error,
            delay,
        )
        if delay > 0:
            await asyncio.sleep(delay)

    # Not reachable: the final attempt always returns or re-raises.
    raise RuntimeError("retry loop exited without a result")
async def read_root(request: Request):
    """Render ``index.html`` with the incoming request in the template context.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered with ``app``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def generate_response(query: QueryRequest):
    """Generate a completion for ``query.prompt`` via the local model server.

    Sampling settings are forwarded under ``options``, which is where Ollama's
    ``/api/generate`` reads them. ``temperature`` maps to
    ``options.temperature`` and ``max_tokens`` maps to ``options.num_predict``.
    Settings left as ``None`` are omitted so the server default applies.

    Returns:
        ``{"response": <generated text, or "" if absent>, "model": <model name>}``.

    Raises:
        HTTPException: Status 500 carrying the error text if generation fails.
    """
    model = "llama3.2:1b-papalia"
    # Top-level "temperature"/"max_tokens" keys are ignored by /api/generate;
    # model parameters must be nested inside "options".
    options = {
        key: value
        for key, value in (
            ("temperature", query.temperature),
            ("num_predict", query.max_tokens),
        )
        if value is not None
    }
    payload = {
        "model": model,
        "prompt": query.prompt,
        "options": options,
        "stream": False,  # ask for one JSON object instead of a stream
    }
    try:
        async with httpx.AsyncClient(timeout=90.0) as client:
            result = await generate_with_retries(client, payload)
        return {"response": result.get("response", ""), "model": model}
    except Exception as e:
        logger.exception("Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Report whether the model server at ``http://localhost:11434`` is up.

    The server counts as healthy only if it answers with a successful status.
    Previously any HTTP reply, including a 500, was reported as healthy.
    Failures are reported in the payload rather than raised.

    Returns:
        ``{"status": "healthy"}`` on success, otherwise
        ``{"status": "unhealthy", "error": <message>}``.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get("http://localhost:11434")
            response.raise_for_status()
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
    return {"status": "healthy"}