import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional

# Initialize FastAPI app
app = FastAPI()

# Load the model once at startup
model = gr.load("models/meta-llama/Llama-3.2-3B-Instruct")


class PoemRequest(BaseModel):
    prompt: str = Field(..., description="The prompt for poem generation")
    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0, description="Controls randomness in generation")
    top_p: Optional[float] = Field(0.9, ge=0.1, le=1.0, description="Nucleus sampling parameter")
    top_k: Optional[int] = Field(50, ge=1, le=100, description="Top-k sampling parameter")
    max_length: Optional[int] = Field(200, ge=50, le=500, description="Maximum length of generated text")
    repetition_penalty: Optional[float] = Field(1.1, ge=1.0, le=2.0, description="Penalty for repetition")


class PoemResponse(BaseModel):
    poem: str
    parameters_used: dict


@app.post("/generate_poem")
async def generate_poem(request: PoemRequest) -> PoemResponse:
    """
    Generate a poem based on the provided prompt and parameters.

    Returns:
        PoemResponse: Contains the generated poem and the parameters used
    """
    try:
        # Prepare generation parameters
        generation_config = {
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "max_length": request.max_length,
            "repetition_penalty": request.repetition_penalty,
        }

        # Generate the poem by calling the loaded Gradio model like a function
        response = model(
            request.prompt,
            **generation_config
        )

        return PoemResponse(
            poem=response,
            parameters_used=generation_config
        )
    except Exception as e:
        # Surface generation errors as a standard HTTP 500 response
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
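
# --- Example client (illustrative sketch, not part of the service code) ---
# Assuming the server above is running locally on port 8000, the
# /generate_poem endpoint could be exercised from a separate process with a
# small client like the one below. The `requests` dependency, the sample
# prompt, and the chosen parameter values are assumptions for illustration.
#
# import requests
#
# payload = {
#     "prompt": "Write a short poem about the sea",
#     "temperature": 0.8,
#     "max_length": 150,
# }
# resp = requests.post("http://localhost:8000/generate_poem", json=payload)
# resp.raise_for_status()
# result = resp.json()
# print(result["poem"])
# print(result["parameters_used"])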