# app.py
"""HTTP API around the ``Sidharthan/gemma2_scripter`` PEFT model.

Endpoints:
    POST /generate  -- sample a continuation for the supplied prompt.
    GET  /health    -- liveness probe.

The tokenizer and model are loaded once, at import time; a failure to load
is logged and re-raised so the process does not start in a broken state.
"""
import logging
import os
import threading
from typing import Optional

CACHE_DIR = '/app/cache'

# HF_HOME has to be exported *before* the Hugging Face libraries are imported:
# they resolve their cache location when first imported, so assigning it
# afterwards is too late to take effect. This is why the third-party imports
# below come after this statement.
os.environ['HF_HOME'] = CACHE_DIR

import torch  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from peft import AutoPeftModelForCausalLM  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

logger = logging.getLogger(__name__)

app = FastAPI(title="Gemma Script Generator API")

hf_token = os.getenv('HF_TOKEN')

# Load model and tokenizer
MODEL_NAME = "Sidharthan/gemma2_scripter"

try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        # `token` supersedes the deprecated `use_auth_token` argument.
        token=hf_token,
        cache_dir=CACHE_DIR,
    )
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",  # Will use CPU if GPU not available
        trust_remote_code=True,
        # Same credentials and cache as the tokenizer.
        token=hf_token,
        cache_dir=CACHE_DIR,
    )
except Exception as e:
    logger.exception("Error loading model: %s", e)
    raise

# Generations are executed one at a time. The worker thread keeps the event
# loop free, and the lock keeps the previous "one generation at a time"
# behaviour so concurrent requests cannot multiply memory usage.
_generation_lock = threading.Lock()


class GenerationRequest(BaseModel):
    """Request body for ``POST /generate``.

    Out-of-range sampling parameters are rejected during validation (HTTP
    422) instead of failing later inside ``model.generate`` (HTTP 500).
    """

    # Prompt text; passed to the tokenizer unchanged.
    message: str
    # Forwarded as `max_length` to `model.generate`.
    max_length: Optional[int] = Field(512, ge=1)
    # Sampling temperature; must be strictly positive.
    temperature: Optional[float] = Field(0.7, gt=0)
    # Nucleus-sampling probability mass.
    top_p: Optional[float] = Field(0.95, ge=0, le=1)
    # Top-k cutoff.
    top_k: Optional[int] = Field(50, ge=0)
    # Repetition penalty; must be strictly positive.
    repetition_penalty: Optional[float] = Field(1.2, gt=0)


class GenerationResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # `outputs[0]` decoded with special tokens skipped.
    generated_text: str


def _generate_text(request: GenerationRequest) -> str:
    """Tokenize the prompt, sample one sequence and return it decoded.

    Blocking; intended to be run in a worker thread.
    """
    # Format prompt
    prompt = request.message

    # Tokenize input and place it on the device the model reports, rather
    # than assuming CUDA whenever it is available (`device_map="auto"`
    # decides the actual placement).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate
    with _generation_lock, torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=request.max_length,
            do_sample=True,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode output
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/generate", response_model=GenerationResponse)
async def generate_script(request: GenerationRequest):
    """Generate text for ``request.message``.

    Returns:
        GenerationResponse carrying the decoded output sequence.

    Raises:
        HTTPException: status 500, with the underlying error message as
            detail, if tokenization, generation or decoding fails.
    """
    try:
        # Generation is blocking work; run it off the event loop so other
        # requests (e.g. /health) are still served while it is in progress.
        generated_text = await run_in_threadpool(_generate_text, request)
    except Exception as e:
        logger.exception("Generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return GenerationResponse(generated_text=generated_text)


@app.get("/health")
async def health_check():
    """Liveness probe; always reports a healthy status."""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)