# poetica/app/services/poetry_generation.py
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class PoetryGenerationService:
    def __init__(self):
        model_name = "meta-llama/Llama-3.2-3B-Instruct"  # Adjust model name as needed

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Llama tokenizers do not define a padding token by default; reuse the
        # EOS token so that tokenizing with padding=True does not raise an error.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for efficiency
            device_map="auto",          # Automatically handle device placement
        )

        # Set model to evaluation mode
        self.model.eval()

    async def generate_poem(
        self,
        prompt: str,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.9,
        top_k: Optional[int] = 50,
        max_length: Optional[int] = 200,
        repetition_penalty: Optional[float] = 1.1,
    ) -> str:
        """Generate a poem from the given prompt using sampling-based decoding."""
        try:
            # Tokenize the input prompt
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

            # Move input tensors to the same device as the model
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate text with the specified parameters
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_length=max_length,  # Counts prompt tokens plus generated tokens
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            # Decode the generated text (includes the prompt followed by the poem)
            generated_text = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            return generated_text
        except Exception as e:
            # Re-raise with context while preserving the original traceback
            raise RuntimeError(f"Error generating poem: {e}") from e

    def __del__(self):
        # Clean up resources
        try:
            del self.model
            del self.tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # Release cached GPU memory
        except Exception:
            pass
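

# A minimal usage sketch, not part of the original service: it assumes the
# Llama checkpoint is accessible locally or via the Hugging Face Hub, and the
# prompt and sampling parameters below are illustrative only. Because
# generate_poem is a coroutine, it is driven here with asyncio.run().
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        service = PoetryGenerationService()
        poem = await service.generate_poem(
            prompt="Write a short poem about the sea at dusk.",
            temperature=0.8,
            max_length=150,
        )
        print(poem)

    asyncio.run(_demo())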