abhisheksan committed
Commit 110ce35 · 1 Parent(s): 4bb8b78

Update model configuration and enhance initialization logic: adjust BASE_DIR for the container environment, implement model download functionality, and improve the health check response

Files changed (2):
  1. app/config.py +2 -5
  2. main.py +58 -24
app/config.py CHANGED
@@ -1,16 +1,13 @@
 import os
 from pathlib import Path
 
-# Base project directory
-BASE_DIR = Path(__file__).resolve().parent.parent
-
-# Model settings
+# Base project directory (adjusted for container environment)
+BASE_DIR = Path("/app")
 MODEL_DIR = BASE_DIR / "models"
 MODEL_NAME = "llama-2-7b-chat.q4_K_M.gguf"
 MODEL_PATH = MODEL_DIR / MODEL_NAME
 
 # Ensure model directory exists
 MODEL_DIR.mkdir(parents=True, exist_ok=True)
-
 # Model download URL
 MODEL_URL = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
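Hardcoding BASE_DIR to /app pins the config to the container's filesystem layout, so running outside the container (tests, local development) would still read and write under /app. A minimal sketch of an environment-variable override that keeps the container default; the APP_HOME variable name is an assumption, not part of this commit:

import os
from pathlib import Path

# Hypothetical override: fall back to the container layout (/app)
# unless APP_HOME points somewhere else (e.g. a local checkout).
BASE_DIR = Path(os.environ.get("APP_HOME", "/app"))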
main.py CHANGED
@@ -1,10 +1,10 @@
 from fastapi import FastAPI, HTTPException, status
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from typing import Optional, List
 from ctransformers import AutoModelForCausalLM
 import time
 import logging
-from app.config import MODEL_PATH
+from app.config import MODEL_PATH, MODEL_URL
 
 # Configure logging
 logging.basicConfig(
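The new ConfigDict import supports the ModelInfo change in the next hunk: Pydantic v2 treats model_ as a protected namespace, so field names like model_name and model_path trigger a namespace-conflict warning unless the namespace is cleared. A self-contained illustration:

from pydantic import BaseModel, ConfigDict

class ModelInfo(BaseModel):
    # Without this line, Pydantic v2 warns that "model_name" and
    # "model_path" conflict with its protected "model_" namespace.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    model_path: str

info = ModelInfo(model_name="Llama-2-7B-Chat", model_path="/app/models/x.gguf")
print(info.model_name)  # no warning emitted at class definition time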
@@ -49,25 +49,23 @@ class PoetryResponse(BaseModel):
     style: str
 
 class ModelInfo(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
     status: str
-    model_name: str
     model_path: str
+    model_name: str
     supported_styles: List[str]
     max_context_length: int
 
-@app.on_event("startup")
-async def startup_event():
-    """Initialize the model during startup"""
-    global model
-    try:
-        if not MODEL_PATH.exists():
-            raise FileNotFoundError(
-                f"Model file not found at {MODEL_PATH}. "
-                "Please run download_model.py first."
-            )
+def initialize_model():
+    """Initialize the model and return it"""
+    if not MODEL_PATH.exists():
+        logger.error(f"Model not found at {MODEL_PATH}")
+        return None
 
+    try:
         logger.info(f"Loading model from {MODEL_PATH}")
-        model = AutoModelForCausalLM.from_pretrained(
+        return AutoModelForCausalLM.from_pretrained(
             str(MODEL_PATH.parent),
             model_file=MODEL_PATH.name,
             model_type="llama",
@@ -75,10 +73,17 @@ async def startup_event():
             context_length=512,
             gpu_layers=0 # CPU only
         )
-        logger.info("Model loaded successfully")
     except Exception as e:
-        logger.error(f"Failed to load model: {str(e)}")
-        raise RuntimeError("Failed to initialize model")
+        logger.error(f"Error loading model: {str(e)}")
+        return None
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the model during startup"""
+    global model
+    model = initialize_model()
+    if model is None:
+        logger.warning("Model failed to load but service will start anyway")
 
 @app.get(
     "/health",
@@ -88,14 +93,10 @@ async def startup_event():
 )
 async def health_check():
     """Check if the model is loaded and get basic information"""
-    if model is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Model not loaded"
-        )
+    model_status = "ready" if model is not None else "not_loaded"
 
     return ModelInfo(
-        status="ready",
+        status=model_status,
         model_name="Llama-2-7B-Chat",
         model_path=str(MODEL_PATH),
         supported_styles=[
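With this hunk, /health stops returning 503 when the model is absent and instead reports the state in the response body, which lets an orchestrator distinguish "process up, model missing" from "process down". A quick check against a running instance (uvicorn's default host and port assumed):

import requests

resp = requests.get("http://127.0.0.1:8000/health")
resp.raise_for_status()  # succeeds now even when the model never loaded
print(resp.json()["status"])  # "ready" or "not_loaded"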
@@ -119,7 +120,7 @@ async def generate_poem(request: PoetryRequest):
     if model is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Model not loaded"
+            detail="Model not loaded. Please check /health endpoint for status."
         )
 
     try:
@@ -159,6 +160,39 @@ async def generate_poem(request: PoetryRequest):
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Failed to generate poem: {str(e)}"
         )
+
+def download_model():
+    """Download the model if it doesn't exist"""
+    import requests
+    from tqdm import tqdm
+
+    if MODEL_PATH.exists():
+        logger.info(f"Model already exists at {MODEL_PATH}")
+        return
+
+    logger.info(f"Downloading model to {MODEL_PATH}")
+    try:
+        response = requests.get(MODEL_URL, stream=True)
+        response.raise_for_status()
+        total_size = int(response.headers.get('content-length', 0))
+
+        with open(MODEL_PATH, 'wb') as file, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit='iB',
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
+            for data in response.iter_content(chunk_size=1024):
+                size = file.write(data)
+                pbar.update(size)
+
+        logger.info("Model downloaded successfully")
+    except Exception as e:
+        logger.error(f"Error downloading model: {str(e)}")
+        if MODEL_PATH.exists():
+            MODEL_PATH.unlink()
+        raise
 
 if __name__ == "__main__":
     import uvicorn
 
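Note that download_model() is defined but nothing in main.py invokes it, so as committed a missing model still has to be fetched manually. One plausible wiring, shown here as an assumption rather than part of the commit, is to attempt the download before loading at startup:

# Hypothetical glue, not in this commit: fetch the weights first,
# then fall through to the existing loader.
def initialize_model_with_download():
    if not MODEL_PATH.exists():
        try:
            download_model()
        except Exception:
            return None  # health check will then report "not_loaded"
    return initialize_model()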