poetica / main.py
abhisheksan's picture
Update model configuration and fix tokenizer files; remove outdated model binary
8702989
import os
from typing import Optional, Dict, Any, List
from fastapi import FastAPI, HTTPException, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import logging
import sys
from pydantic import BaseModel, Field, validator
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from contextlib import asynccontextmanager
import asyncio
from functools import lru_cache
import numpy as np
from datetime import datetime
import re
# Constants
BASE_MODEL_DIR = "./models/"
MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
CACHE_SIZE = 1024
MODEL_CONFIG = GPT2Config(
n_positions=400,
n_ctx=400,
n_embd=384,
n_layer=6,
n_head=6,
vocab_size=50257,
bos_token_id=50256,
eos_token_id=50256,
use_cache=True,
)
class GenerateRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=500)
max_length: Optional[int] = Field(default=100, ge=10, le=500)
temperature: float = Field(default=0.9, ge=0.1, le=2.0)
top_k: int = Field(default=50, ge=1, le=100)
top_p: float = Field(default=0.95, ge=0.1, le=1.0)
repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0)
style: Optional[str] = Field(default="free_verse",
description="Poetry style: free_verse, haiku, sonnet")
@validator('prompt')
def validate_prompt(cls, v):
v = ' '.join(v.split())
return v
class PoemFormatter:
"""Handles poem formatting and processing"""
@staticmethod
def format_free_verse(text: str) -> List[str]:
lines = re.split(r'[.!?]+|\n+', text)
lines = [line.strip() for line in lines if line.strip()]
formatted_lines = []
for line in lines:
if len(line) > 40:
parts = line.split(',')
formatted_lines.extend(part.strip() for part in parts if part.strip())
else:
formatted_lines.append(line)
return formatted_lines
@staticmethod
def format_haiku(text: str) -> List[str]:
words = text.split()
lines = []
current_line = []
syllable_count = 0
for word in words:
syllables = len(re.findall(r'[aeiou]+', word.lower()))
if syllable_count + syllables <= 5 and len(lines) == 0:
current_line.append(word)
syllable_count += syllables
elif syllable_count + syllables <= 7 and len(lines) == 1:
current_line.append(word)
syllable_count += syllables
elif syllable_count + syllables <= 5 and len(lines) == 2:
current_line.append(word)
syllable_count += syllables
else:
if current_line:
lines.append(' '.join(current_line))
current_line = [word]
syllable_count = syllables
if len(lines) == 3:
break
if current_line and len(lines) < 3:
lines.append(' '.join(current_line))
return lines[:3]
@staticmethod
def format_sonnet(text: str) -> List[str]:
words = text.split()
lines = []
current_line = []
target_line_length = 10
for word in words:
current_line.append(word)
if len(current_line) >= target_line_length:
lines.append(' '.join(current_line))
current_line = []
if len(lines) >= 14:
break
if current_line and len(lines) < 14:
lines.append(' '.join(current_line))
return lines[:14]
class ModelManager:
def __init__(self):
self.model = None
self.tokenizer = None
self._lock = asyncio.Lock()
self.request_count = 0
self.last_cleanup = datetime.now()
self.poem_formatter = PoemFormatter()
async def initialize(self) -> bool:
try:
self._setup_logging()
logger.info(f"Initializing model on device: {DEVICE}")
self.tokenizer = await self._load_tokenizer()
await self._load_and_optimize_model()
logger.info("Model and tokenizer loaded successfully")
return True
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
logger.exception("Detailed traceback:")
return False
@staticmethod
def _setup_logging():
global logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handlers = [logging.StreamHandler(sys.stdout)]
try:
log_dir = os.path.join(os.getcwd(), 'logs')
os.makedirs(log_dir, exist_ok=True)
handlers.append(logging.FileHandler(
os.path.join(log_dir, f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log')
))
except Exception as e:
print(f"Warning: Could not create log file: {e}")
for handler in handlers:
handler.setFormatter(formatter)
logger.addHandler(handler)
@lru_cache(maxsize=CACHE_SIZE)
async def _load_tokenizer(self):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
async def _load_and_optimize_model(self):
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
self.model = GPT2LMHeadModel(MODEL_CONFIG)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
self.model.load_state_dict(state_dict, strict=False)
self.model.to(DEVICE)
self.model.eval()
if DEVICE.type == 'cuda':
torch.backends.cudnn.benchmark = True
self.model = torch.jit.script(self.model)
dummy_input = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
with torch.no_grad():
self.model(dummy_input)
@torch.no_grad()
async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
async with self._lock:
try:
self.request_count += 1
await self._check_cleanup()
inputs = await self._prepare_inputs(request.prompt)
outputs = await self._generate_optimized(inputs, request)
return await self._process_outputs(outputs, request)
except Exception as e:
logger.error(f"Error generating text: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
async def _prepare_inputs(self, prompt: str):
poetry_prompt = f"Write a poem about: {prompt}\n\nPoem:"
tokens = self.tokenizer.encode(poetry_prompt, return_tensors='pt')
return tokens.to(DEVICE)
async def _generate_optimized(self, inputs, request: GenerateRequest):
attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=DEVICE)
style_params = {
"haiku": {"max_length": 50, "repetition_penalty": 1.3},
"sonnet": {"max_length": 200, "repetition_penalty": 1.2},
"free_verse": {"max_length": request.max_length, "repetition_penalty": request.repetition_penalty}
}
params = style_params.get(request.style, style_params["free_verse"])
return self.model.generate(
inputs,
attention_mask=attention_mask,
max_length=params["max_length"],
num_return_sequences=1,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=params["repetition_penalty"],
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
use_cache=True,
no_repeat_ngram_size=3,
early_stopping=True,
bad_words_ids=[[self.tokenizer.encode(word)[0]] for word in
['http', 'www', 'com', ':', '/', '#']],
min_length=20,
)
async def _process_outputs(self, outputs, request: GenerateRequest):
raw_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
prompt_pattern = f"Write a poem about: {request.prompt}\n\nPoem:"
poem_text = raw_text.replace(prompt_pattern, '').strip()
if request.style == "haiku":
formatted_lines = PoemFormatter.format_haiku(poem_text)
elif request.style == "sonnet":
formatted_lines = PoemFormatter.format_sonnet(poem_text)
else:
formatted_lines = PoemFormatter.format_free_verse(poem_text)
return {
"poem": {
"title": self._generate_title(poem_text),
"lines": formatted_lines,
"style": request.style
},
"original_prompt": request.prompt,
"parameters": {
"max_length": request.max_length,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"repetition_penalty": request.repetition_penalty
},
"metadata": {
"device": DEVICE.type,
"model_type": "GPT2",
"timestamp": datetime.now().isoformat()
}
}
def _generate_title(self, poem_text: str) -> str:
words = poem_text.split()[:6]
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'}
key_words = [word for word in words if word.lower() not in stop_words]
if key_words:
title = ' '.join(key_words[:3]).capitalize()
return title
return "Untitled"
async def _check_cleanup(self):
if self.request_count % 100 == 0:
if DEVICE.type == 'cuda':
torch.cuda.empty_cache()
self.last_cleanup = datetime.now()
@asynccontextmanager
async def lifespan(app: FastAPI):
if not await model_manager.initialize():
logger.error("Failed to initialize model manager")
yield
if model_manager.model is not None:
del model_manager.model
if model_manager.tokenizer is not None:
del model_manager.tokenizer
if DEVICE.type == 'cuda':
torch.cuda.empty_cache()
app = FastAPI(
title="Poetry Generation API",
description="Optimized API for generating poetry using GPT-2",
version="2.0.0",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
model_manager = ModelManager()
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": model_manager.model is not None,
"tokenizer_loaded": model_manager.tokenizer is not None,
"device": DEVICE.type,
"request_count": model_manager.request_count,
"last_cleanup": model_manager.last_cleanup.isoformat(),
"system_info": {
"cuda_available": torch.cuda.is_available(),
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
}
}
@app.post("/generate")
async def generate_text(
request: GenerateRequest,
background_tasks: BackgroundTasks
):
try:
result = await model_manager.generate(request)
if model_manager.request_count % 100 == 0:
background_tasks.add_task(torch.cuda.empty_cache)
return JSONResponse(
content=result,
status_code=status.HTTP_200_OK
)
except Exception as e:
logger.error(f"Error in generate_text: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)