Commit 110ce35 · Parent(s): 4bb8b78
Update model configuration and enhance initialization logic; adjust BASE_DIR for container, implement model download functionality, and improve health check response
Files changed:
  app/config.py  +2 -5
  main.py        +58 -24
app/config.py CHANGED
@@ -1,16 +1,13 @@
 import os
 from pathlib import Path
 
-# Base project directory
-BASE_DIR = Path(
-
-# Model settings
+# Base project directory (adjusted for container environment)
+BASE_DIR = Path("/app")
 MODEL_DIR = BASE_DIR / "models"
 MODEL_NAME = "llama-2-7b-chat.q4_K_M.gguf"
 MODEL_PATH = MODEL_DIR / MODEL_NAME
 
 # Ensure model directory exists
 MODEL_DIR.mkdir(parents=True, exist_ok=True)
-
 # Model download URL
 MODEL_URL = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
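With BASE_DIR pinned to /app, the derived paths become fixed absolute locations inside the container rather than depending on where the repository is checked out, and importing app.config creates the models directory as a side effect of the mkdir call. A minimal sanity check of what the constants now resolve to (a sketch, assuming it runs inside the container image):

from app.config import MODEL_DIR, MODEL_PATH

# The import above already ran MODEL_DIR.mkdir(parents=True, exist_ok=True),
# so /app/models exists by the time these print.
print(MODEL_DIR)    # /app/models
print(MODEL_PATH)   # /app/models/llama-2-7b-chat.q4_K_M.gguf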
main.py CHANGED
@@ -1,10 +1,10 @@
 from fastapi import FastAPI, HTTPException, status
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from typing import Optional, List
 from ctransformers import AutoModelForCausalLM
 import time
 import logging
-from app.config import MODEL_PATH
+from app.config import MODEL_PATH, MODEL_URL
 
 # Configure logging
 logging.basicConfig(
@@ -49,25 +49,23 @@ class PoetryResponse(BaseModel):
     style: str
 
 class ModelInfo(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
     status: str
-    model_name: str
     model_path: str
+    model_name: str
     supported_styles: List[str]
     max_context_length: int
 
-@app.on_event("startup")
-async def startup_event():
-    """Load the model on startup"""
-    global model
-    try:
-        if not MODEL_PATH.exists():
-            raise FileNotFoundError(
-                f"Model file not found at {MODEL_PATH}. "
-                "Please run download_model.py first."
-            )
+def initialize_model():
+    """Initialize the model and return it"""
+    if not MODEL_PATH.exists():
+        logger.error(f"Model not found at {MODEL_PATH}")
+        return None
 
+    try:
         logger.info(f"Loading model from {MODEL_PATH}")
-        model = AutoModelForCausalLM.from_pretrained(
+        return AutoModelForCausalLM.from_pretrained(
             str(MODEL_PATH.parent),
             model_file=MODEL_PATH.name,
             model_type="llama",
@@ -75,10 +73,17 @@ async def startup_event():
             context_length=512,
             gpu_layers=0  # CPU only
         )
-        logger.info("Model loaded successfully")
     except Exception as e:
-        logger.error(f"
-
+        logger.error(f"Error loading model: {str(e)}")
+        return None
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the model during startup"""
+    global model
+    model = initialize_model()
+    if model is None:
+        logger.warning("Model failed to load but service will start anyway")
 
 @app.get(
     "/health",
@@ -88,14 +93,10 @@ async def startup_event():
 )
 async def health_check():
     """Check if the model is loaded and get basic information"""
-    if model is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Model not loaded"
-        )
+    model_status = "ready" if model is not None else "not_loaded"
 
     return ModelInfo(
-        status=
+        status=model_status,
         model_name="Llama-2-7B-Chat",
         model_path=str(MODEL_PATH),
         supported_styles=[
@@ -119,7 +120,7 @@ async def generate_poem(request: PoetryRequest):
     if model is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Model not loaded"
+            detail="Model not loaded. Please check /health endpoint for status."
         )
 
     try:
@@ -159,6 +160,39 @@ async def generate_poem(request: PoetryRequest):
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Failed to generate poem: {str(e)}"
         )
+
+def download_model():
+    """Download the model if it doesn't exist"""
+    import requests
+    from tqdm import tqdm
+
+    if MODEL_PATH.exists():
+        logger.info(f"Model already exists at {MODEL_PATH}")
+        return
+
+    logger.info(f"Downloading model to {MODEL_PATH}")
+    try:
+        response = requests.get(MODEL_URL, stream=True)
+        response.raise_for_status()
+        total_size = int(response.headers.get('content-length', 0))
+
+        with open(MODEL_PATH, 'wb') as file, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit='iB',
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
+            for data in response.iter_content(chunk_size=1024):
+                size = file.write(data)
+                pbar.update(size)
+
+        logger.info("Model downloaded successfully")
+    except Exception as e:
+        logger.error(f"Error downloading model: {str(e)}")
+        if MODEL_PATH.exists():
+            MODEL_PATH.unlink()
+        raise
 
 if __name__ == "__main__":
     import uvicorn
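With these changes the app boots even when the GGUF file is absent: startup logs a warning, /health reports the state, and the poem-generation endpoint returns 503 until the model is loaded. A quick smoke test of that behavior; this sketch assumes the service listens on the usual Hugging Face Spaces port 7860 and uses the httpx client, neither of which is part of this commit:

import httpx

# Before the weights exist, the service is up but degraded.
info = httpx.get("http://localhost:7860/health").json()
print(info["status"])   # "not_loaded" until the model file is present, "ready" after

download_model() can then be invoked manually (for example, python -c "from main import download_model; download_model()") and the service restarted, after which /health should report "ready".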