# poetry_generation.py
import asyncio
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import os
import logging
from functools import lru_cache
import concurrent.futures

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a smaller model
model_name = "facebook/opt-125m"  # Much smaller than Llama
class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._initialized = False
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            # Initialize quantization config
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16
            )

            # Initialize tokenizer and model with quantization
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16
            )

            # Enable model optimizations
            self.model.config.use_cache = True

            # Set model to evaluation mode
            self.model.eval()

            # device_map="auto" already handles placement, and calling .cuda()
            # on a 4-bit bitsandbytes model is not supported, so only enable
            # the cuDNN autotuner when a GPU is present.
            if torch.cuda.is_available():
                torch.backends.cudnn.benchmark = True

            ModelManager._initialized = True

    def __del__(self):
        try:
            del self.model
            del self.tokenizer
            torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")
class PoetryGenerationService:
    def __init__(self):
        model_manager = ModelManager()
        self.model = model_manager.model
        self.tokenizer = model_manager.tokenizer

        # Pre-compile common prompt templates
        self.prompt_template = "Write a short poem about {}\n"

    def preload_models(self):
        """Preload the models during application startup"""
        try:
            _ = ModelManager()

            # Warmup generation
            self.generate_poem("warmup")

            logger.info("Models preloaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error preloading models: {str(e)}")
            raise Exception("Failed to preload models") from e
    def generate_poem(
        self,
        prompt: str,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.9,
        top_k: Optional[int] = 50,
        max_length: Optional[int] = 150,
        repetition_penalty: Optional[float] = 1.1
    ) -> str:
        try:
            # Format prompt using template
            formatted_prompt = self.prompt_template.format(prompt)

            # Optimize input processing
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=64  # Reduced from 128
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.inference_mode():  # Faster than torch.no_grad()
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_length=max_length,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1  # Disable beam search for speed
                )

            return self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
        except Exception as e:
            logger.error(f"Error generating poem: {str(e)}")
            return f"Error generating poem: {str(e)}"
    async def generate_poems_async(self, prompts: List[str]) -> List[str]:
        # Run the blocking generation calls in a thread pool so the event
        # loop stays responsive while poems are produced concurrently.
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            poems = await asyncio.gather(
                *[loop.run_in_executor(executor, self.generate_poem, prompt)
                  for prompt in prompts]
            )
        return poems
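

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original service): shows one
# synchronous call and the async batch helper. Assumes the module is run
# directly and that the quantized model loads successfully on this machine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    service = PoetryGenerationService()
    service.preload_models()

    # Single synchronous generation
    print(service.generate_poem("the sea at dawn"))

    # Several prompts handled concurrently via the async helper
    poems = asyncio.run(
        service.generate_poems_async(["autumn rain", "city lights"])
    )
    for poem in poems:
        print(poem)
        print("---")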