Optimize model loading and error handling in PoetryGenerationService; implement async poem generation and enhance application startup process
6dbb459
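The "enhance application startup process" part of this commit lives outside the file below. As a rough sketch only, one way to wire the service's preload_models and generate_poems_async into an application might look like the following; the FastAPI app, the lifespan handler, and the /poems route are illustrative assumptions, not part of this commit.

# app.py (illustrative sketch, not part of this commit)
from contextlib import asynccontextmanager
from fastapi import FastAPI
from poetry_generation import PoetryGenerationService

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Build the service (which loads the model) and warm it up once at startup
    app.state.poetry = PoetryGenerationService()
    app.state.poetry.preload_models()
    yield

app = FastAPI(lifespan=lifespan)

@app.post("/poems")
async def poems(topics: list[str]):
    # Generate all requested poems concurrently without blocking the event loop
    return {"poems": await app.state.poetry.generate_poems_async(topics)}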
# poetry_generation.py
import asyncio
import concurrent.futures
import logging
from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a smaller model (much smaller than Llama) to keep load time and memory low
model_name = "facebook/opt-125m"
class ModelManager:
    """Singleton that loads the tokenizer and model exactly once."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._initialized = False
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            # Initialize the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            if torch.cuda.is_available():
                # 4-bit quantization (bitsandbytes) requires a CUDA device;
                # device_map="auto" already places the quantized model on the GPU,
                # so no explicit .cuda() call is needed (or allowed).
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16
                )
                torch.backends.cudnn.benchmark = True
            else:
                # Fall back to a plain full-precision load on CPU
                self.model = AutoModelForCausalLM.from_pretrained(model_name)

            # Enable the KV cache and switch to evaluation mode
            self.model.config.use_cache = True
            self.model.eval()

            ModelManager._initialized = True

    def __del__(self):
        try:
            del self.model
            del self.tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")
class PoetryGenerationService:
    def __init__(self):
        model_manager = ModelManager()
        self.model = model_manager.model
        self.tokenizer = model_manager.tokenizer
        # Pre-compile common prompt templates
        self.prompt_template = "Write a short poem about {}\n"

    def preload_models(self):
        """Preload the models during application startup."""
        try:
            _ = ModelManager()
            # Warm-up generation so the first real request is not slowed by lazy initialization
            self.generate_poem("warmup")
            logger.info("Models preloaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error preloading models: {str(e)}")
            raise Exception("Failed to preload models") from e
    def generate_poem(
        self,
        prompt: str,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.9,
        top_k: Optional[int] = 50,
        max_length: Optional[int] = 150,
        repetition_penalty: Optional[float] = 1.1
    ) -> str:
        try:
            # Format the prompt using the pre-built template
            formatted_prompt = self.prompt_template.format(prompt)

            # Keep the encoded prompt short to speed up generation
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=64  # reduced from 128
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.inference_mode():  # faster than torch.no_grad()
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_length=max_length,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1  # disable beam search for speed; early_stopping only applies to beam search
                )

            return self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
        except Exception as e:
            logger.error(f"Error generating poem: {str(e)}")
            return f"Error generating poem: {str(e)}"
    async def generate_poems_async(self, prompts: List[str]) -> List[str]:
        # Run the blocking generate_poem calls in a thread pool so the event
        # loop stays responsive while several poems are generated concurrently
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            poems = await asyncio.gather(
                *[loop.run_in_executor(executor, self.generate_poem, prompt)
                  for prompt in prompts]
            )
        return poems
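For completeness, a minimal way to exercise the async batch API from a script, assuming the module is saved as poetry_generation.py as the header comment suggests (the topics below are placeholders):

# example_usage.py (illustrative)
import asyncio

from poetry_generation import PoetryGenerationService

async def main():
    service = PoetryGenerationService()
    # Placeholder topics; each one is formatted into the prompt template
    poems = await service.generate_poems_async(["the sea", "autumn rain", "city lights"])
    for poem in poems:
        print(poem)
        print("---")

if __name__ == "__main__":
    asyncio.run(main())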