# Spaces:
# Runtime error
# Runtime error
# app.py
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from peft import AutoPeftModelForCausalLM
from pydantic import BaseModel
from transformers import AutoTokenizer
app = FastAPI(title="Gemma Script Generator API")

# Hub repository id passed to both the tokenizer and the model loader.
MODEL_NAME = "Sidharthan/gemma2_scripter"

# Load the tokenizer and model once at import time so that every request
# reuses the same objects.
# NOTE(review): trust_remote_code=True allows the Hub repository to execute
# its own Python code while loading -- confirm the repository is trusted.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",  # Will use CPU if GPU not available
        trust_remote_code=True,
        # load_in_4bit=True  # currently disabled
    )
except Exception as e:
    # Report the failure on stdout, then re-raise so start-up aborts instead
    # of serving requests without a model.
    print(f"Error loading model: {e}")
    raise
class GenerationRequest(BaseModel):
    """Request body for text generation.

    Only ``message`` is required; every other field has a default and is
    forwarded unchanged to ``model.generate`` by ``generate_script``.
    """

    # Prompt text; tokenized as-is, without any template applied.
    message: str
    # Forwarded as ``max_length`` to ``model.generate``.
    max_length: Optional[int] = 512
    # Forwarded as ``temperature`` (sampling is always enabled).
    temperature: Optional[float] = 0.7
    # Forwarded as ``top_p``.
    top_p: Optional[float] = 0.95
    # Forwarded as ``top_k``.
    top_k: Optional[int] = 50
    # Forwarded as ``repetition_penalty``.
    repetition_penalty: Optional[float] = 1.2
class GenerationResponse(BaseModel):
    """Response body returned by ``generate_script``."""

    # Decoded model output with special tokens stripped.
    generated_text: str
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler in this view -- confirm that it is registered on ``app``.
async def generate_script(request: GenerationRequest):
    """Generate text from ``request.message`` using the module-level model.

    The message is tokenized as-is, sampled with the parameters carried by
    the request, and the first returned sequence is decoded with special
    tokens removed.

    Args:
        request: Prompt and sampling parameters. ``max_length`` is passed to
            ``model.generate`` as ``max_length`` (which, per the Transformers
            documentation, bounds prompt plus generated tokens).

    Returns:
        A ``GenerationResponse`` whose ``generated_text`` is the decoded
        output sequence.

    Raises:
        HTTPException: Status 500 carrying the error message when
            tokenization, generation, or decoding fails.
    """
    try:
        # The raw message is the prompt; no chat/instruction template is added.
        prompt = request.message

        encoded = tokenizer(prompt, return_tensors="pt")
        # Follow the device the model actually lives on (it is loaded with
        # device_map="auto") instead of assuming CUDA whenever a GPU exists.
        target_device = model.device
        inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        # Some tokenizers define no pad token; fall back to EOS so that
        # generate() still receives a valid padding id.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        outputs = model.generate(
            **inputs,
            max_length=request.max_length,
            do_sample=True,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Only one sequence is requested, so decode the first (and only) row.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return GenerationResponse(generated_text=generated_text)
    except Exception as e:
        # Surface any failure as an HTTP 500, keeping the original exception
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible on this
# handler in this view -- confirm that it is registered on ``app``.
async def health_check():
    """Return a constant liveness payload.

    Returns:
        The dictionary ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}
# Script entry point: serve the API with uvicorn when run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)