abhisheksan committed
Commit 6dbb459
1 Parent(s): abc61cb

Optimize model loading and error handling in PoetryGenerationService; implement async poem generation and enhance application startup process

.gitignore ADDED
@@ -0,0 +1,3 @@
+.env
+__pycache__/
+__pycache__/main.cpython-312.pyc

__pycache__/main.cpython-312.pyc DELETED
Binary file (1.88 kB)
app/services/poetry_generation.py CHANGED
@@ -1,5 +1,7 @@
+# poetry_generation.py
+import asyncio
 from typing import Optional, List
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
 import os
 import logging
@@ -9,7 +11,8 @@ import concurrent.futures
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-model_name = "meta-llama/Llama-3.2-1B-Instruct"
+# Use a smaller model
+model_name = "facebook/opt-125m"  # Much smaller than Llama
 
 class ModelManager:
     _instance = None
@@ -22,16 +25,35 @@ class ModelManager:
 
     def __init__(self):
         if not ModelManager._initialized:
-            # Initialize tokenizer and model
+            # Initialize quantization config
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16
+            )
+
+            # Initialize tokenizer and model with quantization
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Load model with optimizations
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype=torch.float16,
-                device_map="auto"
+                quantization_config=quantization_config,
+                device_map="auto",
+                torch_dtype=torch.float16
             )
-            # Set model to evaluation mode and move to GPU
+
+            # Enable model optimizations
+            self.model.config.use_cache = True
+
+            # Set model to evaluation mode
             self.model.eval()
+
+            # Move model to GPU if available
+            if torch.cuda.is_available():
+                self.model = self.model.cuda()
+                torch.backends.cudnn.benchmark = True
+
             ModelManager._initialized = True
 
     def __del__(self):
@@ -42,30 +64,23 @@ class ModelManager:
         except Exception as e:
             logger.error(f"Error during cleanup: {str(e)}")
 
-    @lru_cache(maxsize=1)
-    def get_hf_token() -> str:
-        """Get Hugging Face token from environment variables."""
-        token = os.getenv("HF_TOKEN")
-        if not token:
-            raise EnvironmentError(
-                "HF_TOKEN environment variable not found. "
-                "Please set your Hugging Face access token."
-            )
-        return token
-
 class PoetryGenerationService:
     def __init__(self):
-        # Get model manager instance
         model_manager = ModelManager()
         self.model = model_manager.model
         self.tokenizer = model_manager.tokenizer
+
+        # Pre-compile common prompt templates
+        self.prompt_template = "Write a short poem about {}\n"
 
     def preload_models(self):
         """Preload the models during application startup"""
         try:
-            _ = ModelManager()  # Ensure ModelManager singleton is initialized
+            _ = ModelManager()
+            # Warmup generation
+            self.generate_poem("warmup")
             logger.info("Models preloaded successfully")
-            return True  # Return a meaningful value
+            return True
         except Exception as e:
             logger.error(f"Error preloading models: {str(e)}")
             raise Exception("Failed to preload models") from e
@@ -76,14 +91,25 @@ class PoetryGenerationService:
         temperature: Optional[float] = 0.7,
         top_p: Optional[float] = 0.9,
         top_k: Optional[int] = 50,
-        max_length: Optional[int] = 100,
+        max_length: Optional[int] = 150,
         repetition_penalty: Optional[float] = 1.1
     ) -> str:
         try:
-            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
+            # Format prompt using template
+            formatted_prompt = self.prompt_template.format(prompt)
+
+            # Optimize input processing
+            inputs = self.tokenizer(
+                formatted_prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=64  # Reduced from 128
+            )
+
             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
-            with torch.no_grad():
+            with torch.inference_mode():  # Faster than torch.no_grad()
                 outputs = self.model.generate(
                     inputs["input_ids"],
                     attention_mask=inputs["attention_mask"],
@@ -95,6 +121,8 @@ class PoetryGenerationService:
                     repetition_penalty=repetition_penalty,
                     pad_token_id=self.tokenizer.eos_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
+                    num_beams=1,  # Disable beam search for speed
+                    early_stopping=True
                 )
 
             return self.tokenizer.decode(
@@ -104,9 +132,14 @@ class PoetryGenerationService:
             )
 
         except Exception as e:
-            raise Exception(f"Error generating poem: {str(e)}")
+            logger.error(f"Error generating poem: {str(e)}")
+            return f"Error generating poem: {str(e)}"
 
-    def generate_poems(self, prompts: List[str]) -> List[str]:
+    async def generate_poems_async(self, prompts: List[str]) -> List[str]:
+        loop = asyncio.get_event_loop()
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            poems = list(executor.map(self.generate_poem, prompts))
-            return poems
+            poems = await asyncio.gather(
+                *[loop.run_in_executor(executor, self.generate_poem, prompt)
+                  for prompt in prompts]
+            )
+            return poems
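Two hedged notes on the new service code. First, recent transformers/bitsandbytes versions refuse .cuda() (and .to(...)) on a model loaded with load_in_4bit=True, since device_map="auto" already places the quantized weights, so the explicit CUDA move in ModelManager.__init__ may need to be dropped or guarded on such versions. Second, a minimal sketch of driving the new generate_poems_async batch API from a standalone script; the scaffolding below is illustrative and not part of this commit:

    # Illustrative caller sketch, not part of this commit.
    import asyncio

    from app.services.poetry_generation import PoetryGenerationService

    async def main() -> None:
        # Constructing the service reuses the ModelManager singleton,
        # so the model is loaded once per process.
        service = PoetryGenerationService()
        # Prompts fan out across the thread pool inside generate_poems_async.
        poems = await service.generate_poems_async(
            ["the sea at dusk", "a rusting bicycle"]
        )
        for poem in poems:
            print(poem)

    if __name__ == "__main__":
        asyncio.run(main())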
main.py CHANGED
@@ -1,6 +1,7 @@
+# main.py
 import asyncio
 from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import FastAPI, BackgroundTasks
 from app.api.endpoints.poetry import router as poetry_router
 import os
 import logging
@@ -11,10 +12,11 @@ from huggingface_hub import login
 from functools import lru_cache
 from app.services.poetry_generation import PoetryGenerationService
 
-# Configure logging once at module level
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Global poetry service instance
+poetry_service = None
 @lru_cache()
 def get_hf_token() -> str:
     """Get Hugging Face token from environment variables."""
@@ -35,26 +37,29 @@ def init_huggingface():
     except Exception as e:
         logger.error(f"Failed to login to Hugging Face: {str(e)}")
         raise
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Initialize Hugging Face authentication first
+    global poetry_service
+
+    # Initialize Hugging Face authentication
    init_huggingface()
-
+
     # Initialize poetry service and preload models
     poetry_service = PoetryGenerationService()
-
+
     try:
-        preload_result = poetry_service.preload_models()
-        if asyncio.iscoroutine(preload_result):
-            await preload_result
-        else:
-            preload_result  # Call directly if synchronous
+        # Preload models in background
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(poetry_service.preload_models)
+
+        logger.info("Application startup complete")
+        yield
+
     except Exception as e:
-        logger.error(f"Error preloading models: {str(e)}")
+        logger.error(f"Error during startup: {str(e)}")
         raise
-
-    yield  # Continue to application startup
+    finally:
+        logger.info("Shutting down application")
 
 app = FastAPI(lifespan=lifespan)
 app.include_router(poetry_router, prefix="/api/v1/poetry")
@@ -63,14 +68,16 @@ app.include_router(poetry_router, prefix="/api/v1/poetry")
 async def lifecheck():
     return Response("OK", media_type="text/plain")
 
-def get_port() -> int:
-    return int(os.getenv("PORT", "8000"))
-
 if __name__ == "__main__":
     import uvicorn
 
-    port = get_port()
-    app.mount("/static", StaticFiles(directory="static"), name="static")
+    port = int(os.getenv("PORT", "8000"))
 
-    logger.info(f"Starting FastAPI server on port {port}")
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    # Configure uvicorn with optimized settings
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=port,
+        loop="uvloop",  # Faster event loop implementation
+        http="httptools",  # Faster HTTP protocol implementation
+    )
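Two hedged caveats on the startup changes. A BackgroundTasks object constructed by hand inside lifespan is never executed by FastAPI (background tasks only run when attached to a response), so the preload task as written will not fire; the model is still loaded eagerly because PoetryGenerationService() itself initializes the ModelManager singleton. Separately, loop="uvloop" and http="httptools" assume those optional packages are installed (uvloop is unavailable on Windows). An illustrative fallback launcher, not part of this commit, that keeps the fast paths when available:

    # Illustrative launcher sketch: prefer uvloop/httptools when importable,
    # otherwise fall back to uvicorn's "auto" defaults.
    import importlib.util
    import os

    import uvicorn

    from main import app  # the FastAPI app defined above

    if __name__ == "__main__":
        port = int(os.getenv("PORT", "8000"))
        loop_impl = "uvloop" if importlib.util.find_spec("uvloop") else "auto"
        http_impl = "httptools" if importlib.util.find_spec("httptools") else "auto"
        uvicorn.run(app, host="0.0.0.0", port=port, loop=loop_impl, http=http_impl)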