abhisheksan committed
Commit 6feef58
1 Parent(s): af80dec

Enhance ModelManager and PoetryGenerationService with optimizations and new features


- Update model name in ModelManager for improved performance
- Compile the model with torch.compile for more efficient inference
- Add caching for Hugging Face token retrieval
- Modify generate_poem method to include truncation and max_length adjustments
- Make generate_poem synchronous and introduce generate_poems for batch poem generation via a thread pool (see the usage sketch below)
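A minimal usage sketch of the updated service (hypothetical caller code, not part of this commit; assumes the file is importable as app.services.poetry_generation and that the model weights are available):

    from app.services.poetry_generation import PoetryGenerationService

    service = PoetryGenerationService()

    # generate_poem is now synchronous, so callers no longer await it
    poem = service.generate_poem("Write a haiku about autumn rain", max_length=100)
    print(poem)

    # Batch path: ThreadPoolExecutor maps generate_poem over the prompts.
    # Threads can overlap because PyTorch releases the GIL in its C++ kernels.
    poems = service.generate_poems([
        "A sonnet about the sea",
        "Free verse about city lights",
    ])
    for poem in poems:
        print(poem)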

Files changed (1)
  1. app/services/poetry_generation.py +34 -6
app/services/poetry_generation.py CHANGED
@@ -1,6 +1,13 @@
 from typing import Optional
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import os
+import logging
+from functools import lru_cache
+import concurrent.futures
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class ModelManager:
     _instance = None
@@ -13,7 +20,7 @@ class ModelManager:
 
     def __init__(self):
         if not ModelManager._initialized:
-            model_name = "meta-llama/Llama-3.2-3B-Instruct"
+            model_name = "meta-llama/Llama-3.2-1B-Instruct"  # smaller model for lower latency
 
             # Initialize tokenizer and model
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -24,8 +31,10 @@ class ModelManager:
                 device_map="auto"
             )
 
-            # Set model to evaluation mode
+            # Set model to evaluation mode (device placement is handled by device_map="auto")
             self.model.eval()
+            # Compile the model for faster inference (PyTorch 2.x)
+            self.model = torch.compile(self.model)
             ModelManager._initialized = True
 
     def __del__(self):
@@ -36,24 +45,36 @@ class ModelManager:
         except:
             pass
 
+@lru_cache(maxsize=1)
+def get_hf_token() -> str:
+    """Get the Hugging Face token from environment variables (cached)."""
+    token = os.getenv("HF_TOKEN")
+    if not token:
+        raise EnvironmentError(
+            "HF_TOKEN environment variable not found. "
+            "Please set your Hugging Face access token."
+        )
+    return token
+
 class PoetryGenerationService:
     def __init__(self):
         # Get model manager instance
         model_manager = ModelManager()
         self.model = model_manager.model
         self.tokenizer = model_manager.tokenizer
+        self.cache = {}
 
-    async def generate_poem(
+    def generate_poem(
         self,
         prompt: str,
         temperature: Optional[float] = 0.7,
         top_p: Optional[float] = 0.9,
         top_k: Optional[int] = 50,
-        max_length: Optional[int] = 200,
+        max_length: Optional[int] = 100,
         repetition_penalty: Optional[float] = 1.1
     ) -> str:
         try:
-            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
+            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
             with torch.no_grad():
@@ -77,4 +98,11 @@ class PoetryGenerationService:
             )
 
         except Exception as e:
-            raise Exception(f"Error generating poem: {str(e)}")
+            raise Exception(f"Error generating poem: {str(e)}") from e
+
+    def generate_poems(self, prompts: list[str]) -> list[str]:
+        # Map generate_poem over all prompts on a thread pool; PyTorch
+        # releases the GIL in its C++ kernels, so threads can overlap.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            poems = list(executor.map(self.generate_poem, prompts))
+        return poems
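Note: get_hf_token is added above but never called in the visible hunks. One plausible wiring, an assumption rather than part of this commit, is to pass the cached token into the from_pretrained calls that download the gated Llama weights (token= is the standard transformers keyword for authenticated loads):

    # Sketch only: route the cached HF token into the gated-model downloads
    self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=get_hf_token())
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=get_hf_token(),  # lru_cache ensures HF_TOKEN is read only once
        device_map="auto",
    )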