import os
from typing import Optional, Dict, Any, List
from fastapi import FastAPI, HTTPException, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import logging
import sys
from pydantic import BaseModel, Field, validator
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from contextlib import asynccontextmanager
import asyncio
from datetime import datetime
import re
# Constants
BASE_MODEL_DIR = "./models/"
MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
CACHE_SIZE = 1024

# Compact 6-layer GPT-2 configuration; it must line up with the checkpoint at MODEL_PATH
MODEL_CONFIG = GPT2Config(
    n_positions=400,
    n_ctx=400,
    n_embd=384,
    n_layer=6,
    n_head=6,
    vocab_size=50257,
    bos_token_id=50256,
    eos_token_id=50256,
    use_cache=True,
)
class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=500)
    max_length: Optional[int] = Field(default=100, ge=10, le=500)
    temperature: float = Field(default=0.9, ge=0.1, le=2.0)
    top_k: int = Field(default=50, ge=1, le=100)
    top_p: float = Field(default=0.95, ge=0.1, le=1.0)
    repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0)
    style: Optional[str] = Field(default="free_verse",
                                 description="Poetry style: free_verse, haiku, sonnet")

    @validator('prompt')
    def validate_prompt(cls, v):
        # Collapse runs of whitespace into single spaces
        v = ' '.join(v.split())
        return v
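
# Illustrative request body for the schema above (values are examples, not defaults):
# {
#     "prompt": "the sea at dusk",
#     "max_length": 120,
#     "temperature": 0.8,
#     "style": "haiku"
# }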
class PoemFormatter:
    """Handles poem formatting and processing"""

    @staticmethod
    def format_free_verse(text: str) -> List[str]:
        # Break on sentence punctuation or newlines, then split long lines at commas
        lines = re.split(r'[.!?]+|\n+', text)
        lines = [line.strip() for line in lines if line.strip()]
        formatted_lines = []
        for line in lines:
            if len(line) > 40:
                parts = line.split(',')
                formatted_lines.extend(part.strip() for part in parts if part.strip())
            else:
                formatted_lines.append(line)
        return formatted_lines
    @staticmethod
    def format_haiku(text: str) -> List[str]:
        # Approximate syllables by counting vowel groups (a rough heuristic),
        # then fill lines toward the 5-7-5 pattern
        words = text.split()
        lines = []
        current_line = []
        syllable_count = 0
        for word in words:
            syllables = len(re.findall(r'[aeiou]+', word.lower()))
            if syllable_count + syllables <= 5 and len(lines) == 0:
                current_line.append(word)
                syllable_count += syllables
            elif syllable_count + syllables <= 7 and len(lines) == 1:
                current_line.append(word)
                syllable_count += syllables
            elif syllable_count + syllables <= 5 and len(lines) == 2:
                current_line.append(word)
                syllable_count += syllables
            else:
                if current_line:
                    lines.append(' '.join(current_line))
                current_line = [word]
                syllable_count = syllables
                if len(lines) == 3:
                    break
        if current_line and len(lines) < 3:
            lines.append(' '.join(current_line))
        return lines[:3]
    @staticmethod
    def format_sonnet(text: str) -> List[str]:
        # Roughly ten words per line, capped at the sonnet's fourteen lines
        words = text.split()
        lines = []
        current_line = []
        target_line_length = 10
        for word in words:
            current_line.append(word)
            if len(current_line) >= target_line_length:
                lines.append(' '.join(current_line))
                current_line = []
            if len(lines) >= 14:
                break
        if current_line and len(lines) < 14:
            lines.append(' '.join(current_line))
        return lines[:14]
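
# Illustrative usage of the formatter (expected output shown as a comment):
#   PoemFormatter.format_free_verse(
#       "The moon rises over silent water. Shadows stretch across the shore.")
#   -> ["The moon rises over silent water", "Shadows stretch across the shore"]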
class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self._lock = asyncio.Lock()
        self.request_count = 0
        self.last_cleanup = datetime.now()
        self.poem_formatter = PoemFormatter()
    async def initialize(self) -> bool:
        try:
            self._setup_logging()
            logger.info(f"Initializing model on device: {DEVICE}")
            self.tokenizer = await self._load_tokenizer()
            await self._load_and_optimize_model()
            logger.info("Model and tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            logger.exception("Detailed traceback:")
            return False
    def _setup_logging(self):
        global logger
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        handlers = [logging.StreamHandler(sys.stdout)]
        try:
            log_dir = os.path.join(os.getcwd(), 'logs')
            os.makedirs(log_dir, exist_ok=True)
            handlers.append(logging.FileHandler(
                os.path.join(log_dir, f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log')
            ))
        except Exception as e:
            print(f"Warning: Could not create log file: {e}")
        for handler in handlers:
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    async def _load_tokenizer(self):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer
    async def _load_and_optimize_model(self):
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
        self.model = GPT2LMHeadModel(MODEL_CONFIG)
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(DEVICE)
        self.model.eval()
        if DEVICE.type == 'cuda':
            # TorchScript is avoided here: a scripted module would not expose
            # the Hugging Face generate() method this service relies on
            torch.backends.cudnn.benchmark = True
        # Warm-up forward pass so the first real request does not pay lazy-init costs
        dummy_input = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            self.model(dummy_input)
    async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
        # Serialize generation: the single model instance is not safe for
        # concurrent forward passes
        async with self._lock:
            try:
                self.request_count += 1
                await self._check_cleanup()
                inputs = await self._prepare_inputs(request.prompt)
                outputs = await self._generate_optimized(inputs, request)
                return await self._process_outputs(outputs, request)
            except Exception as e:
                logger.error(f"Error generating text: {str(e)}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=str(e)
                )
    async def _prepare_inputs(self, prompt: str):
        poetry_prompt = f"Write a poem about: {prompt}\n\nPoem:"
        tokens = self.tokenizer.encode(poetry_prompt, return_tensors='pt')
        return tokens.to(DEVICE)
    async def _generate_optimized(self, inputs, request: GenerateRequest):
        attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=DEVICE)
        # Per-style overrides; free verse passes the caller's settings through
        style_params = {
            "haiku": {"max_length": 50, "repetition_penalty": 1.3},
            "sonnet": {"max_length": 200, "repetition_penalty": 1.2},
            "free_verse": {"max_length": request.max_length, "repetition_penalty": request.repetition_penalty}
        }
        params = style_params.get(request.style, style_params["free_verse"])
        return self.model.generate(
            inputs,
            attention_mask=attention_mask,
            max_length=params["max_length"],
            num_return_sequences=1,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=params["repetition_penalty"],
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            no_repeat_ngram_size=3,
            # Discourage URL-like fragments; only each word's first BPE token is banned
            bad_words_ids=[[self.tokenizer.encode(word)[0]] for word in
                           ['http', 'www', 'com', ':', '/', '#']],
            min_length=20,
        )
    async def _process_outputs(self, outputs, request: GenerateRequest):
        raw_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the instruction scaffold so only the generated poem remains
        prompt_pattern = f"Write a poem about: {request.prompt}\n\nPoem:"
        poem_text = raw_text.replace(prompt_pattern, '').strip()
        if request.style == "haiku":
            formatted_lines = PoemFormatter.format_haiku(poem_text)
        elif request.style == "sonnet":
            formatted_lines = PoemFormatter.format_sonnet(poem_text)
        else:
            formatted_lines = PoemFormatter.format_free_verse(poem_text)
        return {
            "poem": {
                "title": self._generate_title(poem_text),
                "lines": formatted_lines,
                "style": request.style
            },
            "original_prompt": request.prompt,
            "parameters": {
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "repetition_penalty": request.repetition_penalty
            },
            "metadata": {
                "device": DEVICE.type,
                "model_type": "GPT2",
                "timestamp": datetime.now().isoformat()
            }
        }
    def _generate_title(self, poem_text: str) -> str:
        words = poem_text.split()[:6]
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'}
        key_words = [word for word in words if word.lower() not in stop_words]
        if key_words:
            title = ' '.join(key_words[:3]).capitalize()
            return title
        return "Untitled"
    async def _check_cleanup(self):
        # Release cached CUDA memory every 100 requests
        if self.request_count % 100 == 0:
            if DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
            self.last_cleanup = datetime.now()
@asynccontextmanager
async def lifespan(app: FastAPI):
    if not await model_manager.initialize():
        logger.error("Failed to initialize model manager")
    yield
    # Shutdown: free the model and any cached GPU memory
    if model_manager.model is not None:
        del model_manager.model
    if model_manager.tokenizer is not None:
        del model_manager.tokenizer
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
app = FastAPI(
    title="Poetry Generation API",
    description="Optimized API for generating poetry using GPT-2",
    version="2.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

model_manager = ModelManager()
# Route path "/health" is an assumption; the source omits the decorator
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model_manager.model is not None,
        "tokenizer_loaded": model_manager.tokenizer is not None,
        "device": DEVICE.type,
        "request_count": model_manager.request_count,
        "last_cleanup": model_manager.last_cleanup.isoformat(),
        "system_info": {
            "cuda_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
        }
    }
# Route path "/generate" is an assumption; the source omits the decorator
@app.post("/generate")
async def generate_text(
    request: GenerateRequest,
    background_tasks: BackgroundTasks
):
    try:
        result = await model_manager.generate(request)
        if DEVICE.type == 'cuda' and model_manager.request_count % 100 == 0:
            background_tasks.add_task(torch.cuda.empty_cache)
        return JSONResponse(
            content=result,
            status_code=status.HTTP_200_OK
        )
    except Exception as e:
        logger.error(f"Error in generate_text: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
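
# --- Illustrative entry point: a minimal sketch assuming the service is run
# with uvicorn, the usual ASGI server for FastAPI; host and port are placeholders ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the assumed /generate route:
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "autumn rain", "style": "haiku"}'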