# poetry_generation.py
import asyncio
import concurrent.futures
import logging
from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a smaller model
model_name = "facebook/opt-125m"  # Much smaller than Llama


class ModelManager:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelManager._initialized:
            # Initialize quantization config
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16
            )

            # Initialize tokenizer and model with quantization
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with optimizations; device_map="auto" already places
            # the quantized weights on the available GPU(s), so no manual
            # .cuda() call is needed (and is not supported for 4-bit models).
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16
            )

            # Enable the KV cache for faster autoregressive decoding
            self.model.config.use_cache = True

            # Set model to evaluation mode
            self.model.eval()

            # Enable cuDNN autotuning when a GPU is available
            if torch.cuda.is_available():
                torch.backends.cudnn.benchmark = True

            ModelManager._initialized = True

    def __del__(self):
        try:
            del self.model
            del self.tokenizer
            torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")


class PoetryGenerationService:
    def __init__(self):
        model_manager = ModelManager()
        self.model = model_manager.model
        self.tokenizer = model_manager.tokenizer

        # Pre-compile common prompt templates
        self.prompt_template = "Write a short poem about {}\n"

    def preload_models(self):
        """Preload the models during application startup."""
        try:
            _ = ModelManager()
            # Warmup generation
            self.generate_poem("warmup")
            logger.info("Models preloaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error preloading models: {str(e)}")
            raise Exception("Failed to preload models") from e

    def generate_poem(
        self,
        prompt: str,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.9,
        top_k: Optional[int] = 50,
        max_length: Optional[int] = 150,
        repetition_penalty: Optional[float] = 1.1
    ) -> str:
        try:
            # Format prompt using template
            formatted_prompt = self.prompt_template.format(prompt)

            # Optimize input processing
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=64  # Reduced from 128
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.inference_mode():  # Faster than torch.no_grad()
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_length=max_length,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_beams=1  # Greedy sampling; beam search disabled for speed
                )

            return self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
        except Exception as e:
            logger.error(f"Error generating poem: {str(e)}")
            return f"Error generating poem: {str(e)}"

    async def generate_poems_async(self, prompts: List[str]) -> List[str]:
        # Run the blocking generate_poem calls in a thread pool so the event
        # loop stays responsive while several poems are produced concurrently.
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            poems = await asyncio.gather(
                *[loop.run_in_executor(executor, self.generate_poem, prompt)
                  for prompt in prompts]
            )
        return poems
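

# --- Usage sketch ------------------------------------------------------------
# A minimal example of how this module might be exercised when run directly.
# The __main__ guard and the sample prompts below are illustrative additions,
# not part of the original service API.
if __name__ == "__main__":
    service = PoetryGenerationService()
    service.preload_models()

    # Single, synchronous generation
    print(service.generate_poem("the sea at dusk"))

    # Batched generation through the async thread-pool wrapper
    sample_prompts = ["autumn rain", "a city at night"]
    results = asyncio.run(service.generate_poems_async(sample_prompts))
    for topic, poem in zip(sample_prompts, results):
        print(f"--- {topic} ---\n{poem}\n")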