anderson-ufrj committed
Commit · 796f99b
1 Parent(s): f93caf8

feat(ml): implement comprehensive ML Pipeline with versioning and A/B testing
- Add MLTrainingPipeline with support for multiple algorithms
- Implement model versioning with MLflow tracking
- Create A/B testing framework with multiple allocation strategies
- Add API endpoints for model training and management
- Support Thompson sampling and epsilon-greedy strategies
- Include statistical significance testing for A/B tests
- Add comprehensive unit tests for ML components
- ROADMAP_MELHORIAS_2025.md +1 -5
- pyproject.toml +2 -0
- src/api/app.py +6 -0
- src/api/routes/ml_pipeline.py +451 -0
- src/ml/__init__.py +31 -13
- src/ml/ab_testing.py +512 -0
- src/ml/training_pipeline.py +466 -756
- tests/unit/ml/__init__.py +0 -0
- tests/unit/ml/test_training_pipeline.py +369 -0
ROADMAP_MELHORIAS_2025.md
CHANGED

@@ -251,10 +251,6 @@ Este documento apresenta um roadmap estruturado para melhorias no backend do Cid
     - [ ] Comentários e anotações
     - [ ] Workspaces compartilhados

-  2. **Mobile & PWA**
-     - [ ] Progressive Web App
-     - [ ] Offline capabilities
-     - [ ] Push notifications

  **Entregáveis**: Platform enterprise-ready

@@ -325,4 +321,4 @@ Este documento apresenta um roadmap estruturado para melhorias no backend do Cid

  ---

- *Este roadmap é um documento vivo e deve ser revisado a cada sprint com base no feedback e aprendizados.*
+ *Este roadmap é um documento vivo e deve ser revisado a cada sprint com base no feedback e aprendizados.*
pyproject.toml
CHANGED

@@ -61,6 +61,8 @@ dependencies = [
     "hdbscan>=0.8.33",
     "shap>=0.43.0",
     "lime>=0.2.0.1",
+    "mlflow>=2.9.0",
+    "joblib>=1.3.2",

     # Async processing
     "celery[redis]>=5.3.4",
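The two new dependencies back the model versioning described in the commit message: MLflow for experiment/run tracking and joblib for serializing fitted scikit-learn estimators. The following is a minimal sketch of that combination, not code from the repository; the experiment name, parameters, and file paths are illustrative assumptions.

# Sketch: typical mlflow + joblib usage for tracked, versioned models.
# Names and values below are illustrative, not taken from this repository.
import joblib
import mlflow
import numpy as np
from sklearn.ensemble import IsolationForest

mlflow.set_experiment("cidadao-ml-pipeline")  # hypothetical experiment name

with mlflow.start_run(run_name="anomaly_isolation_forest_v1"):
    model = IsolationForest(n_estimators=100, contamination=0.05, random_state=42)
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("contamination", 0.05)

    X = np.random.randn(1000, 10)  # synthetic data, like the /train endpoint demo below
    model.fit(X)

    scores = model.decision_function(X)
    mlflow.log_metric("mean_anomaly_score", float(scores.mean()))

    joblib.dump(model, "isolation_forest_v1.joblib")   # persist the fitted estimator
    mlflow.log_artifact("isolation_forest_v1.joblib")  # attach it to the MLflow run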
src/api/app.py
CHANGED

@@ -521,6 +521,12 @@ app.include_router(
     tags=["Geographic Data"]
 )

+from src.api.routes import ml_pipeline
+app.include_router(
+    ml_pipeline.router,
+    tags=["ML Pipeline"]
+)
+

 # Global exception handler
 @app.exception_handler(CidadaoAIError)
src/api/routes/ml_pipeline.py
ADDED

@@ -0,0 +1,451 @@
"""
ML Pipeline API Routes

This module provides API endpoints for training, versioning, and
A/B testing ML models.
"""

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
import numpy as np

from src.api.dependencies import get_current_user
from src.ml.training_pipeline import get_training_pipeline
from src.ml.ab_testing import get_ab_testing, TrafficAllocationStrategy
from src.core import get_logger


logger = get_logger(__name__)
router = APIRouter(prefix="/api/v1/ml")


class TrainModelRequest(BaseModel):
    """Request model for training ML models."""
    model_type: str = Field(..., description="Type of model (anomaly, fraud, pattern)")
    algorithm: str = Field(..., description="Algorithm to use (isolation_forest, etc)")
    dataset_id: Optional[str] = Field(None, description="Dataset identifier")
    hyperparameters: Optional[Dict[str, Any]] = Field(default_factory=dict)
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)


class PromoteModelRequest(BaseModel):
    """Request model for promoting models."""
    model_id: str = Field(..., description="Model identifier")
    version: int = Field(..., description="Model version")
    status: str = Field("production", description="Target status")


class ABTestRequest(BaseModel):
    """Request model for creating A/B tests."""
    test_name: str = Field(..., description="Unique test name")
    model_a_id: str = Field(..., description="Model A identifier")
    model_a_version: Optional[int] = Field(None, description="Model A version")
    model_b_id: str = Field(..., description="Model B identifier")
    model_b_version: Optional[int] = Field(None, description="Model B version")
    allocation_strategy: str = Field("random", description="Allocation strategy")
    traffic_split: List[float] = Field([0.5, 0.5], description="Traffic split")
    success_metric: str = Field("f1_score", description="Success metric")
    minimum_sample_size: int = Field(1000, description="Minimum samples")
    significance_level: float = Field(0.05, description="Significance level")
    auto_stop: bool = Field(True, description="Auto stop on winner")
    duration_hours: Optional[int] = Field(None, description="Max duration")


class RecordPredictionRequest(BaseModel):
    """Request model for recording predictions in A/B test."""
    model_selection: str = Field(..., description="model_a or model_b")
    success: bool = Field(..., description="Prediction success")
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)


@router.post("/train", response_model=Dict[str, Any])
async def train_model(
    request: TrainModelRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """
    Train a new ML model.

    This endpoint initiates model training with the specified algorithm
    and parameters. Training runs asynchronously in the background.
    """
    try:
        pipeline = await get_training_pipeline()

        # For demo purposes, generate synthetic training data
        # In production, this would load from dataset_id
        if request.model_type == "anomaly":
            # Generate anomaly detection data
            n_samples = 1000
            n_features = 10
            X_train = np.random.randn(n_samples, n_features)
            # Add some anomalies
            anomalies = np.random.randn(50, n_features) * 3
            X_train = np.vstack([X_train, anomalies])
            y_train = None  # Unsupervised
        elif request.model_type == "fraud":
            # Generate fraud detection data
            n_samples = 1000
            n_features = 15
            X_train = np.random.randn(n_samples, n_features)
            y_train = np.random.choice([0, 1], size=n_samples, p=[0.95, 0.05])
        else:
            # Pattern recognition data
            n_samples = 800
            n_features = 20
            X_train = np.random.randn(n_samples, n_features)
            y_train = np.random.choice([0, 1, 2], size=n_samples)

        # Start training
        result = await pipeline.train_model(
            model_type=request.model_type,
            algorithm=request.algorithm,
            X_train=X_train,
            y_train=y_train,
            hyperparameters=request.hyperparameters,
            metadata={
                **request.metadata,
                "user_id": current_user["id"],
                "dataset_id": request.dataset_id
            }
        )

        return result

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/models", response_model=List[Dict[str, Any]])
async def list_models(
    model_type: Optional[str] = None,
    current_user: Dict = Depends(get_current_user)
):
    """List all available models with their versions."""
    try:
        pipeline = await get_training_pipeline()

        # Get models from registry
        models = []
        for model_id, registry in pipeline.model_registry.items():
            if model_type and not model_id.startswith(model_type):
                continue

            models.append({
                "model_id": model_id,
                "versions": len(registry["versions"]),
                "latest_version": max(
                    (v["version"] for v in registry["versions"]),
                    default=0
                ),
                "created_at": registry["created_at"],
                "production_version": next(
                    (v["version"] for v in registry["versions"]
                     if v.get("status") == "production"),
                    None
                )
            })

        return models

    except Exception as e:
        logger.error(f"Failed to list models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/models/{model_id}/versions", response_model=List[Dict[str, Any]])
async def list_model_versions(
    model_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """List all versions of a specific model."""
    try:
        pipeline = await get_training_pipeline()

        if model_id not in pipeline.model_registry:
            raise HTTPException(status_code=404, detail="Model not found")

        versions = []
        for version in pipeline.model_registry[model_id]["versions"]:
            versions.append({
                "version": version["version"],
                "status": version["status"],
                "metrics": version["metrics"],
                "created_at": version["created_at"],
                "promoted_at": version.get("promoted_at")
            })

        return versions

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to list versions: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/models/{model_id}/metrics", response_model=Dict[str, Any])
async def get_model_metrics(
    model_id: str,
    version: Optional[int] = None,
    current_user: Dict = Depends(get_current_user)
):
    """Get metrics for a specific model version."""
    try:
        pipeline = await get_training_pipeline()
        metrics = await pipeline.get_model_metrics(model_id, version)

        return {
            "model_id": model_id,
            "version": version or "latest",
            "metrics": metrics
        }

    except Exception as e:
        logger.error(f"Failed to get metrics: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/models/promote", response_model=Dict[str, Any])
async def promote_model(
    request: PromoteModelRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Promote a model version to production."""
    try:
        pipeline = await get_training_pipeline()
        success = await pipeline.promote_model(
            request.model_id,
            request.version,
            request.status
        )

        if not success:
            raise HTTPException(status_code=500, detail="Promotion failed")

        return {
            "success": True,
            "model_id": request.model_id,
            "version": request.version,
            "status": request.status,
            "message": f"Model promoted to {request.status}"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to promote model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ab-test/create", response_model=Dict[str, Any])
async def create_ab_test(
    request: ABTestRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Create a new A/B test."""
    try:
        ab_framework = await get_ab_testing()

        # Validate allocation strategy
        try:
            strategy = TrafficAllocationStrategy(request.allocation_strategy)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid allocation strategy: {request.allocation_strategy}"
            )

        test_config = await ab_framework.create_test(
            test_name=request.test_name,
            model_a=(request.model_a_id, request.model_a_version),
            model_b=(request.model_b_id, request.model_b_version),
            allocation_strategy=strategy,
            traffic_split=tuple(request.traffic_split),
            success_metric=request.success_metric,
            minimum_sample_size=request.minimum_sample_size,
            significance_level=request.significance_level,
            auto_stop=request.auto_stop,
            duration_hours=request.duration_hours
        )

        return test_config

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to create A/B test: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ab-test/{test_name}/start", response_model=Dict[str, Any])
async def start_ab_test(
    test_name: str,
    current_user: Dict = Depends(get_current_user)
):
    """Start an A/B test."""
    try:
        ab_framework = await get_ab_testing()
        success = await ab_framework.start_test(test_name)

        if not success:
            raise HTTPException(status_code=500, detail="Failed to start test")

        return {
            "success": True,
            "test_name": test_name,
            "message": "A/B test started"
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to start A/B test: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/ab-test/{test_name}/allocate", response_model=Dict[str, Any])
async def allocate_model_for_test(
    test_name: str,
    user_id: Optional[str] = None
):
    """Get model allocation for a user in an A/B test."""
    try:
        ab_framework = await get_ab_testing()
        model_id, version = await ab_framework.allocate_model(test_name, user_id)

        return {
            "model_id": model_id,
            "version": version,
            "test_name": test_name,
            "user_id": user_id
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to allocate model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ab-test/{test_name}/record", response_model=Dict[str, Any])
async def record_prediction(
    test_name: str,
    request: RecordPredictionRequest
):
    """Record a prediction result for an A/B test."""
    try:
        ab_framework = await get_ab_testing()
        await ab_framework.record_prediction(
            test_name,
            request.model_selection,
            request.success,
            request.metadata
        )

        return {
            "success": True,
            "test_name": test_name,
            "model_selection": request.model_selection
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to record prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/ab-test/{test_name}/status", response_model=Dict[str, Any])
async def get_ab_test_status(
    test_name: str,
    current_user: Dict = Depends(get_current_user)
):
    """Get current status and results of an A/B test."""
    try:
        ab_framework = await get_ab_testing()
        status = await ab_framework.get_test_status(test_name)

        # Include latest analysis if available
        if "latest_analysis" in status:
            status["analysis"] = status["latest_analysis"]

        return status

    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to get test status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ab-test/{test_name}/stop", response_model=Dict[str, Any])
async def stop_ab_test(
    test_name: str,
    reason: str = "Manual stop",
    current_user: Dict = Depends(get_current_user)
):
    """Stop an A/B test."""
    try:
        ab_framework = await get_ab_testing()
        success = await ab_framework.stop_test(test_name, reason)

        if not success:
            raise HTTPException(status_code=500, detail="Failed to stop test")

        return {
            "success": True,
            "test_name": test_name,
            "message": f"A/B test stopped: {reason}"
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to stop A/B test: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/ab-test/{test_name}/promote-winner", response_model=Dict[str, Any])
async def promote_ab_test_winner(
    test_name: str,
    current_user: Dict = Depends(get_current_user)
):
    """Promote the winning model from an A/B test to production."""
    try:
        ab_framework = await get_ab_testing()
        success = await ab_framework.promote_winner(test_name)

        if not success:
            raise HTTPException(status_code=500, detail="Failed to promote winner")

        return {
            "success": True,
            "test_name": test_name,
            "message": "Winner promoted to production"
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to promote winner: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/ab-test/active", response_model=List[Dict[str, Any]])
async def list_active_ab_tests(
    current_user: Dict = Depends(get_current_user)
):
    """List all active A/B tests."""
    try:
        ab_framework = await get_ab_testing()
        active_tests = await ab_framework.list_active_tests()

        return active_tests

    except Exception as e:
        logger.error(f"Failed to list active tests: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
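Taken together, these routes cover the train / version / promote / A/B-test lifecycle. The following client walkthrough is a hypothetical sketch of how they might be exercised end to end; the base URL, bearer-token header, and model identifiers are assumptions (the real auth scheme is whatever get_current_user expects), only the paths and payload fields come from the file above.

# Hypothetical client walkthrough of the new endpoints (httpx assumed installed).
import httpx

BASE = "http://localhost:8000/api/v1/ml"          # assumed deployment URL
HEADERS = {"Authorization": "Bearer <token>"}      # placeholder credentials

with httpx.Client(base_url=BASE, headers=HEADERS, timeout=60.0) as client:
    # 1. Train a candidate anomaly model (synthetic data is generated server-side).
    train = client.post("/train", json={
        "model_type": "anomaly",
        "algorithm": "isolation_forest",
        "hyperparameters": {"n_estimators": 100},
    })
    print(train.json())

    # 2. Create and start an A/B test between two registered models (ids illustrative).
    client.post("/ab-test/create", json={
        "test_name": "anomaly_if_vs_lof",
        "model_a_id": "anomaly_isolation_forest",
        "model_b_id": "anomaly_lof",
        "allocation_strategy": "thompson_sampling",
        "minimum_sample_size": 1000,
    })
    client.post("/ab-test/anomaly_if_vs_lof/start")

    # 3. Ask which model to serve for a request, then report the outcome.
    alloc = client.get("/ab-test/anomaly_if_vs_lof/allocate",
                       params={"user_id": "user-123"}).json()
    selection = "model_a" if alloc["model_id"] == "anomaly_isolation_forest" else "model_b"
    client.post("/ab-test/anomaly_if_vs_lof/record",
                json={"model_selection": selection, "success": True})

    # 4. Inspect the running test and its latest statistical analysis.
    print(client.get("/ab-test/anomaly_if_vs_lof/status").json())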
src/ml/__init__.py
CHANGED

@@ -1,19 +1,37 @@
Old content (several removed import and __all__ lines are truncated in this view):
 """
 This module provides ML capabilities including:
 - Anomaly detection algorithms
 - Pattern analysis and correlation detection
 - Predictive models for spending analysis
 """
 from .
 __all__ = [
     "
     "
 ]

New content:
 """
 ML Pipeline Module

 This module provides machine learning capabilities including:
 - Model training pipeline
 - Model versioning
 - A/B testing framework
 """

 from src.ml.training_pipeline import (
     MLTrainingPipeline,
     training_pipeline,
     get_training_pipeline
 )

 from src.ml.ab_testing import (
     ABTestFramework,
     ABTestStatus,
     TrafficAllocationStrategy,
     ab_testing,
     get_ab_testing
 )


 __all__ = [
     # Training Pipeline
     "MLTrainingPipeline",
     "training_pipeline",
     "get_training_pipeline",

     # A/B Testing
     "ABTestFramework",
     "ABTestStatus",
     "TrafficAllocationStrategy",
     "ab_testing",
     "get_ab_testing"
 ]
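With these re-exports, downstream code can import the singleton accessors straight from the package. A minimal sketch of the intended import surface (usage pattern assumed, names taken from __all__ above):

# Sketch of the import surface exposed by the updated src/ml/__init__.py.
from src.ml import get_training_pipeline, get_ab_testing, TrafficAllocationStrategy

async def show_ml_surface():
    pipeline = await get_training_pipeline()   # MLTrainingPipeline singleton
    ab = await get_ab_testing()                # ABTestFramework singleton
    print(type(pipeline).__name__, type(ab).__name__, [s.value for s in TrafficAllocationStrategy])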
src/ml/ab_testing.py
ADDED

@@ -0,0 +1,512 @@
"""
A/B Testing Framework for ML Models

This module provides A/B testing capabilities for comparing model
performance in production environments.
"""

import asyncio
import json
import random
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple, Union
from enum import Enum
import numpy as np
from scipy import stats

from src.core import get_logger
from src.infrastructure.cache.redis_client import get_redis_client
from src.ml.training_pipeline import training_pipeline


logger = get_logger(__name__)


class ABTestStatus(Enum):
    """Status of an A/B test."""
    DRAFT = "draft"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    STOPPED = "stopped"


class TrafficAllocationStrategy(Enum):
    """Strategy for allocating traffic between models."""
    RANDOM = "random"
    WEIGHTED = "weighted"
    EPSILON_GREEDY = "epsilon_greedy"
    THOMPSON_SAMPLING = "thompson_sampling"


class ABTestFramework:
    """
    A/B Testing framework for ML models.

    Features:
    - Multiple allocation strategies
    - Statistical significance testing
    - Real-time performance tracking
    - Automatic winner selection
    - Gradual rollout support
    """

    def __init__(self):
        """Initialize the A/B testing framework."""
        self.active_tests = {}
        self.test_results = {}

    async def create_test(
        self,
        test_name: str,
        model_a: Tuple[str, Optional[int]],  # (model_id, version)
        model_b: Tuple[str, Optional[int]],
        allocation_strategy: TrafficAllocationStrategy = TrafficAllocationStrategy.RANDOM,
        traffic_split: Tuple[float, float] = (0.5, 0.5),
        success_metric: str = "f1_score",
        minimum_sample_size: int = 1000,
        significance_level: float = 0.05,
        auto_stop: bool = True,
        duration_hours: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Create a new A/B test.

        Args:
            test_name: Unique name for the test
            model_a: Model A (control) - (model_id, version)
            model_b: Model B (treatment) - (model_id, version)
            allocation_strategy: Traffic allocation strategy
            traffic_split: Traffic split between models (must sum to 1.0)
            success_metric: Metric to optimize
            minimum_sample_size: Minimum samples before analysis
            significance_level: Statistical significance threshold
            auto_stop: Automatically stop when winner found
            duration_hours: Maximum test duration

        Returns:
            Test configuration
        """
        if test_name in self.active_tests:
            raise ValueError(f"Test {test_name} already exists")

        if abs(sum(traffic_split) - 1.0) > 0.001:
            raise ValueError("Traffic split must sum to 1.0")

        # Load models to verify they exist
        await training_pipeline.load_model(*model_a)
        await training_pipeline.load_model(*model_b)

        test_config = {
            "test_id": f"ab_test_{test_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            "test_name": test_name,
            "model_a": {"model_id": model_a[0], "version": model_a[1]},
            "model_b": {"model_id": model_b[0], "version": model_b[1]},
            "allocation_strategy": allocation_strategy.value,
            "traffic_split": traffic_split,
            "success_metric": success_metric,
            "minimum_sample_size": minimum_sample_size,
            "significance_level": significance_level,
            "auto_stop": auto_stop,
            "status": ABTestStatus.DRAFT.value,
            "created_at": datetime.now().isoformat(),
            "start_time": None,
            "end_time": None,
            "duration_hours": duration_hours,
            "results": {
                "model_a": {"predictions": 0, "successes": 0, "metrics": {}},
                "model_b": {"predictions": 0, "successes": 0, "metrics": {}}
            }
        }

        # Initialize allocation strategy specific params
        if allocation_strategy == TrafficAllocationStrategy.EPSILON_GREEDY:
            test_config["epsilon"] = 0.1  # 10% exploration
        elif allocation_strategy == TrafficAllocationStrategy.THOMPSON_SAMPLING:
            test_config["thompson_params"] = {
                "model_a": {"alpha": 1, "beta": 1},
                "model_b": {"alpha": 1, "beta": 1}
            }

        self.active_tests[test_name] = test_config

        # Save to Redis
        await self._save_test_config(test_config)

        logger.info(f"Created A/B test: {test_name}")
        return test_config

    async def start_test(self, test_name: str) -> bool:
        """Start an A/B test."""
        if test_name not in self.active_tests:
            # Try to load from Redis
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")
            self.active_tests[test_name] = test_config

        test_config = self.active_tests[test_name]

        if test_config["status"] not in [ABTestStatus.DRAFT.value, ABTestStatus.PAUSED.value]:
            raise ValueError(f"Cannot start test in status {test_config['status']}")

        test_config["status"] = ABTestStatus.RUNNING.value
        test_config["start_time"] = datetime.now().isoformat()

        await self._save_test_config(test_config)

        logger.info(f"Started A/B test: {test_name}")
        return True

    async def allocate_model(
        self,
        test_name: str,
        user_id: Optional[str] = None
    ) -> Tuple[str, int]:
        """
        Allocate a model for a user based on the test configuration.

        Args:
            test_name: Test name
            user_id: User identifier for consistent allocation

        Returns:
            Tuple of (model_id, version)
        """
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")

        if test_config["status"] != ABTestStatus.RUNNING.value:
            raise ValueError(f"Test {test_name} is not running")

        # Select model based on allocation strategy
        strategy = TrafficAllocationStrategy(test_config["allocation_strategy"])

        if strategy == TrafficAllocationStrategy.RANDOM:
            selected = await self._random_allocation(test_config, user_id)
        elif strategy == TrafficAllocationStrategy.WEIGHTED:
            selected = await self._weighted_allocation(test_config)
        elif strategy == TrafficAllocationStrategy.EPSILON_GREEDY:
            selected = await self._epsilon_greedy_allocation(test_config)
        elif strategy == TrafficAllocationStrategy.THOMPSON_SAMPLING:
            selected = await self._thompson_sampling_allocation(test_config)
        else:
            selected = "model_a"  # Default fallback

        # Return model info
        model_info = test_config[selected]
        return (model_info["model_id"], model_info["version"])

    async def _random_allocation(
        self,
        test_config: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> str:
        """Random allocation with optional user-based consistency."""
        if user_id:
            # Hash user_id for consistent allocation
            hash_val = hash(user_id + test_config["test_id"]) % 100
            threshold = test_config["traffic_split"][0] * 100
            return "model_a" if hash_val < threshold else "model_b"
        else:
            # Pure random
            return "model_a" if random.random() < test_config["traffic_split"][0] else "model_b"

    async def _weighted_allocation(self, test_config: Dict[str, Any]) -> str:
        """Weighted allocation based on traffic split."""
        return np.random.choice(
            ["model_a", "model_b"],
            p=test_config["traffic_split"]
        )

    async def _epsilon_greedy_allocation(self, test_config: Dict[str, Any]) -> str:
        """Epsilon-greedy allocation (explore vs exploit)."""
        epsilon = test_config.get("epsilon", 0.1)

        if random.random() < epsilon:
            # Explore
            return random.choice(["model_a", "model_b"])
        else:
            # Exploit - choose best performing
            results = test_config["results"]
            rate_a = (results["model_a"]["successes"] /
                      max(results["model_a"]["predictions"], 1))
            rate_b = (results["model_b"]["successes"] /
                      max(results["model_b"]["predictions"], 1))

            return "model_a" if rate_a >= rate_b else "model_b"

    async def _thompson_sampling_allocation(self, test_config: Dict[str, Any]) -> str:
        """Thompson sampling allocation (Bayesian approach)."""
        params = test_config["thompson_params"]

        # Sample from Beta distributions
        sample_a = np.random.beta(params["model_a"]["alpha"], params["model_a"]["beta"])
        sample_b = np.random.beta(params["model_b"]["alpha"], params["model_b"]["beta"])

        return "model_a" if sample_a >= sample_b else "model_b"

    async def record_prediction(
        self,
        test_name: str,
        model_selection: str,  # "model_a" or "model_b"
        success: bool,
        prediction_metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Record a prediction result for the test.

        Args:
            test_name: Test name
            model_selection: Which model was used
            success: Whether prediction was successful
            prediction_metadata: Additional metadata
        """
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")

        # Update results
        results = test_config["results"][model_selection]
        results["predictions"] += 1
        if success:
            results["successes"] += 1

        # Update Thompson sampling parameters if applicable
        if test_config["allocation_strategy"] == TrafficAllocationStrategy.THOMPSON_SAMPLING.value:
            params = test_config["thompson_params"][model_selection]
            if success:
                params["alpha"] += 1
            else:
                params["beta"] += 1

        # Save updated config
        await self._save_test_config(test_config)

        # Check if we should analyze results
        total_predictions = (test_config["results"]["model_a"]["predictions"] +
                             test_config["results"]["model_b"]["predictions"])

        if total_predictions >= test_config["minimum_sample_size"]:
            analysis = await self.analyze_test(test_name)

            if test_config["auto_stop"] and analysis.get("winner"):
                await self.stop_test(test_name, reason="Winner found")

    async def analyze_test(self, test_name: str) -> Dict[str, Any]:
        """
        Analyze test results for statistical significance.

        Returns:
            Analysis results including winner if found
        """
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")

        results_a = test_config["results"]["model_a"]
        results_b = test_config["results"]["model_b"]

        # Calculate conversion rates
        rate_a = results_a["successes"] / max(results_a["predictions"], 1)
        rate_b = results_b["successes"] / max(results_b["predictions"], 1)

        # Perform chi-square test
        contingency_table = np.array([
            [results_a["successes"], results_a["predictions"] - results_a["successes"]],
            [results_b["successes"], results_b["predictions"] - results_b["successes"]]
        ])

        chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)

        # Calculate confidence intervals
        ci_a = self._calculate_confidence_interval(
            results_a["successes"], results_a["predictions"]
        )
        ci_b = self._calculate_confidence_interval(
            results_b["successes"], results_b["predictions"]
        )

        # Determine winner
        winner = None
        if p_value < test_config["significance_level"]:
            winner = "model_a" if rate_a > rate_b else "model_b"

        # Calculate lift
        lift = ((rate_b - rate_a) / rate_a * 100) if rate_a > 0 else 0

        analysis = {
            "model_a": {
                "conversion_rate": rate_a,
                "confidence_interval": ci_a,
                "sample_size": results_a["predictions"]
            },
            "model_b": {
                "conversion_rate": rate_b,
                "confidence_interval": ci_b,
                "sample_size": results_b["predictions"]
            },
            "p_value": p_value,
            "chi_square": chi2,
            "significant": p_value < test_config["significance_level"],
            "winner": winner,
            "lift": lift,
            "analysis_time": datetime.now().isoformat()
        }

        # Update test config with latest analysis
        test_config["latest_analysis"] = analysis
        await self._save_test_config(test_config)

        return analysis

    def _calculate_confidence_interval(
        self,
        successes: int,
        total: int,
        confidence_level: float = 0.95
    ) -> Tuple[float, float]:
        """Calculate confidence interval for conversion rate."""
        if total == 0:
            return (0.0, 0.0)

        rate = successes / total
        z = stats.norm.ppf((1 + confidence_level) / 2)

        # Wilson score interval
        denominator = 1 + z**2 / total
        center = (rate + z**2 / (2 * total)) / denominator
        margin = z * np.sqrt(rate * (1 - rate) / total + z**2 / (4 * total**2)) / denominator

        return (max(0, center - margin), min(1, center + margin))

    async def stop_test(self, test_name: str, reason: str = "Manual stop") -> bool:
        """Stop an A/B test."""
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")

        test_config["status"] = ABTestStatus.STOPPED.value
        test_config["end_time"] = datetime.now().isoformat()
        test_config["stop_reason"] = reason

        # Perform final analysis
        final_analysis = await self.analyze_test(test_name)
        test_config["final_analysis"] = final_analysis

        await self._save_test_config(test_config)

        # Move to completed tests
        self.test_results[test_name] = test_config
        if test_name in self.active_tests:
            del self.active_tests[test_name]

        logger.info(f"Stopped A/B test {test_name}: {reason}")
        return True

    async def get_test_status(self, test_name: str) -> Dict[str, Any]:
        """Get current status of a test."""
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")

        # Add runtime if running
        if test_config["status"] == ABTestStatus.RUNNING.value and test_config["start_time"]:
            start = datetime.fromisoformat(test_config["start_time"])
            runtime = (datetime.now() - start).total_seconds() / 3600
            test_config["runtime_hours"] = runtime

            # Check if should auto-stop due to duration
            if test_config.get("duration_hours") and runtime >= test_config["duration_hours"]:
                await self.stop_test(test_name, reason="Duration limit reached")

        return test_config

    async def promote_winner(self, test_name: str) -> bool:
        """Promote the winning model to production."""
        test_config = self.test_results.get(test_name)
        if not test_config:
            # Try loading completed test
            test_config = await self._load_test_config(test_name)
            if not test_config or test_config["status"] != ABTestStatus.STOPPED.value:
                raise ValueError(f"Test {test_name} not completed")

        final_analysis = test_config.get("final_analysis", {})
        winner = final_analysis.get("winner")

        if not winner:
            raise ValueError(f"No winner found for test {test_name}")

        # Promote winning model
        model_info = test_config[winner]
        success = await training_pipeline.promote_model(
            model_info["model_id"],
            model_info["version"],
            "production"
        )

        if success:
            logger.info(f"Promoted {winner} from test {test_name} to production")

        return success

    async def _save_test_config(self, test_config: Dict[str, Any]):
        """Save test configuration to Redis."""
        redis_client = await get_redis_client()
        key = f"ab_test:{test_config['test_name']}"
        await redis_client.set(
            key,
            json.dumps(test_config),
            ex=86400 * 90  # 90 days
        )

    async def _load_test_config(self, test_name: str) -> Optional[Dict[str, Any]]:
        """Load test configuration from Redis."""
        redis_client = await get_redis_client()
        key = f"ab_test:{test_name}"
        data = await redis_client.get(key)
        return json.loads(data) if data else None

    async def list_active_tests(self) -> List[Dict[str, Any]]:
        """List all active tests."""
        # Load from Redis pattern
        redis_client = await get_redis_client()
        keys = await redis_client.keys("ab_test:*")

        active_tests = []
        for key in keys:
            data = await redis_client.get(key)
            if data:
                test_config = json.loads(data)
                if test_config["status"] in [ABTestStatus.RUNNING.value, ABTestStatus.PAUSED.value]:
                    active_tests.append({
                        "test_name": test_config["test_name"],
                        "status": test_config["status"],
                        "model_a": test_config["model_a"]["model_id"],
                        "model_b": test_config["model_b"]["model_id"],
                        "start_time": test_config.get("start_time"),
                        "predictions": (test_config["results"]["model_a"]["predictions"] +
                                        test_config["results"]["model_b"]["predictions"])
                    })

        return active_tests


# Global A/B testing framework instance
ab_testing = ABTestFramework()


async def get_ab_testing() -> ABTestFramework:
    """Get the global A/B testing framework instance."""
    return ab_testing
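The statistics in this module are standard: Beta-Bernoulli Thompson sampling for allocation, a chi-square test on the 2x2 successes/failures table for significance, and a Wilson score interval for each conversion rate. The standalone sketch below mirrors those formulas outside the framework; the counts are made up purely for illustration.

# Standalone sketch of the statistics ab_testing.py relies on (illustrative numbers).
import numpy as np
from scipy import stats

# Thompson sampling: each arm keeps Beta(alpha, beta); alpha grows with successes,
# beta with failures; the arm with the larger sampled value gets the next request.
params = {"model_a": {"alpha": 46, "beta": 6}, "model_b": {"alpha": 38, "beta": 14}}
sample_a = np.random.beta(params["model_a"]["alpha"], params["model_a"]["beta"])
sample_b = np.random.beta(params["model_b"]["alpha"], params["model_b"]["beta"])
print("allocate:", "model_a" if sample_a >= sample_b else "model_b")

# Chi-square test on the successes/failures contingency table, as in analyze_test().
successes = np.array([450, 520])
predictions = np.array([1000, 1000])
table = np.stack([successes, predictions - successes], axis=1)
chi2, p_value, _, _ = stats.chi2_contingency(table)
print(f"p={p_value:.4f}", "significant" if p_value < 0.05 else "not significant")

# Wilson score interval, matching _calculate_confidence_interval().
def wilson(successes: int, total: int, confidence_level: float = 0.95):
    rate = successes / total
    z = stats.norm.ppf((1 + confidence_level) / 2)
    denom = 1 + z**2 / total
    center = (rate + z**2 / (2 * total)) / denom
    margin = z * np.sqrt(rate * (1 - rate) / total + z**2 / (4 * total**2)) / denom
    return max(0.0, center - margin), min(1.0, center + margin)

print("model_a CI:", wilson(450, 1000))
print("model_b CI:", wilson(520, 1000))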
src/ml/training_pipeline.py
CHANGED
|
@@ -1,813 +1,523 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
"""
|
| 7 |
|
|
|
|
|
|
|
| 8 |
import os
|
| 9 |
-
from
|
| 10 |
-
import
|
| 11 |
-
import torch.nn as nn
|
| 12 |
-
from torch.utils.data import Dataset, DataLoader
|
| 13 |
-
from torch.optim import AdamW
|
| 14 |
-
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 15 |
-
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
| 16 |
-
from typing import Dict, List, Optional, Tuple, Any
|
| 17 |
-
import pandas as pd
|
| 18 |
-
import numpy as np
|
| 19 |
from pathlib import Path
|
| 20 |
-
import
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
import
|
| 24 |
-
from sklearn.metrics import
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
-
|
| 34 |
-
class TrainingConfig:
|
| 35 |
-
"""Configuração de treinamento"""
|
| 36 |
-
|
| 37 |
-
# Hiperparâmetros principais
|
| 38 |
-
learning_rate: float = 2e-5
|
| 39 |
-
batch_size: int = 8
|
| 40 |
-
num_epochs: int = 10
|
| 41 |
-
warmup_steps: int = 1000
|
| 42 |
-
max_grad_norm: float = 1.0
|
| 43 |
-
weight_decay: float = 0.01
|
| 44 |
-
|
| 45 |
-
# Configurações de dados
|
| 46 |
-
max_sequence_length: int = 512
|
| 47 |
-
train_split: float = 0.8
|
| 48 |
-
val_split: float = 0.1
|
| 49 |
-
test_split: float = 0.1
|
| 50 |
-
|
| 51 |
-
# Configurações do modelo
|
| 52 |
-
model_size: str = "medium"
|
| 53 |
-
specialized_tasks: List[str] = None
|
| 54 |
-
use_mixed_precision: bool = True
|
| 55 |
-
gradient_accumulation_steps: int = 4
|
| 56 |
-
|
| 57 |
-
# Configurações de checkpoint
|
| 58 |
-
save_strategy: str = "epoch" # "steps" ou "epoch"
|
| 59 |
-
save_steps: int = 500
|
| 60 |
-
eval_steps: int = 100
|
| 61 |
-
logging_steps: int = 50
|
| 62 |
-
output_dir: str = "./models/cidadao-gpt"
|
| 63 |
-
|
| 64 |
-
# Configurações de avaliação
|
| 65 |
-
eval_strategy: str = "steps"
|
| 66 |
-
metric_for_best_model: str = "eval_f1"
|
| 67 |
-
greater_is_better: bool = True
|
| 68 |
-
early_stopping_patience: int = 3
|
| 69 |
-
|
| 70 |
-
# Configurações de experimentação
|
| 71 |
-
experiment_name: str = "cidadao-gpt-v1"
|
| 72 |
-
use_wandb: bool = True
|
| 73 |
-
wandb_project: str = "cidadao-ai"
|
| 74 |
-
|
| 75 |
-
def __post_init__(self):
|
| 76 |
-
if self.specialized_tasks is None:
|
| 77 |
-
self.specialized_tasks = ["all"]
|
| 78 |
|
| 79 |
|
| 80 |
-
class
|
| 81 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
def __init__(
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
self.
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# Preparar vocabulário especializado
|
| 98 |
-
self._prepare_specialized_vocab()
|
| 99 |
-
|
| 100 |
-
def _load_data(self, data_path: str) -> List[Dict]:
|
| 101 |
-
"""Carregar dados de transparência"""
|
| 102 |
-
|
| 103 |
-
data_file = Path(data_path)
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
"text": "Contrato para aquisição de equipamentos médicos no valor de R$ 2.500.000,00 firmado entre Ministério da Saúde e Empresa XYZ LTDA. Processo licitatório 12345/2024, modalidade pregão eletrônico.",
|
| 131 |
-
"anomaly_label": 0, # Normal
|
| 132 |
-
"financial_risk": 2, # Médio
|
| 133 |
-
"legal_compliance": 1, # Conforme
|
| 134 |
-
"contract_value": 2500000.0,
|
| 135 |
-
"entity_types": [1, 2, 3], # Ministério, Empresa, Equipamento
|
| 136 |
-
"corruption_indicators": []
|
| 137 |
-
},
|
| 138 |
-
{
|
| 139 |
-
"text": "Contrato emergencial sem licitação para fornecimento de insumos hospitalares. Valor: R$ 15.000.000,00. Empresa beneficiária: Alpha Beta Comercial S.A., CNPJ com irregularidades na Receita Federal.",
|
| 140 |
-
"anomaly_label": 2, # Anômalo
|
| 141 |
-
"financial_risk": 4, # Alto
|
| 142 |
-
"legal_compliance": 0, # Não conforme
|
| 143 |
-
"contract_value": 15000000.0,
|
| 144 |
-
"entity_types": [1, 2, 4], # Ministério, Empresa, Insumos
|
| 145 |
-
"corruption_indicators": [1, 3, 5] # Emergencial, Sem licitação, CNPJ irregular
|
| 146 |
-
}
|
| 147 |
-
]
|
| 148 |
-
|
| 149 |
-
# Amplificar dados com variações
|
| 150 |
-
for base_example in contract_examples:
|
| 151 |
-
for i in range(50): # 50 variações de cada exemplo
|
| 152 |
-
example = base_example.copy()
|
| 153 |
-
example["id"] = f"{len(sample_data)}"
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
def
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
text + " Processo administrativo arquivado em sistema SIASG.",
|
| 171 |
-
text + " Valor atualizado conforme INPC/IBGE."
|
| 172 |
-
]
|
| 173 |
|
| 174 |
-
return
|
| 175 |
|
| 176 |
-
def
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
"
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
"valor", "preco", "orcamento", "pagamento", "repasse", "empenho",
|
| 189 |
-
|
| 190 |
-
# Termos jurídicos
|
| 191 |
-
"conformidade", "irregularidade", "infração", "penalidade", "multa",
|
| 192 |
-
|
| 193 |
-
# Indicadores de corrupção
|
| 194 |
-
"superfaturamento", "direcionamento", "cartel", "fraude", "peculato"
|
| 195 |
}
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def __len__(self) -> int:
|
| 202 |
-
return len(self.data)
|
| 203 |
-
|
| 204 |
-
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 205 |
-
item = self.data[idx]
|
| 206 |
-
|
| 207 |
-
# Tokenizar texto
|
| 208 |
-
encoding = self.tokenizer(
|
| 209 |
-
item["text"],
|
| 210 |
-
truncation=True,
|
| 211 |
-
padding="max_length",
|
| 212 |
-
max_length=self.max_length,
|
| 213 |
-
return_tensors="pt"
|
| 214 |
-
)
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
}
|
| 221 |
|
| 222 |
-
#
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
if "financial_risk" in item:
|
| 227 |
-
result["financial_risk_labels"] = torch.tensor(item["financial_risk"], dtype=torch.long)
|
| 228 |
-
|
| 229 |
-
if "legal_compliance" in item:
|
| 230 |
-
result["legal_compliance_labels"] = torch.tensor(item["legal_compliance"], dtype=torch.long)
|
| 231 |
|
| 232 |
-
|
| 233 |
-
if "entity_types" in item:
|
| 234 |
-
entity_types = torch.zeros(self.max_length, dtype=torch.long)
|
| 235 |
-
for i, entity_type in enumerate(item["entity_types"][:self.max_length]):
|
| 236 |
-
entity_types[i] = entity_type
|
| 237 |
-
result["entity_types"] = entity_types
|
| 238 |
-
|
| 239 |
-
if "corruption_indicators" in item:
|
| 240 |
-
corruption_indicators = torch.zeros(self.max_length, dtype=torch.long)
|
| 241 |
-
for i, indicator in enumerate(item["corruption_indicators"][:self.max_length]):
|
| 242 |
-
corruption_indicators[i] = indicator
|
| 243 |
-
result["corruption_indicators"] = corruption_indicators
|
| 244 |
-
|
| 245 |
-
return result
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
class CidadaoTrainer:
|
| 249 |
-
"""Trainer especializado para Cidadão.AI"""
|
| 250 |
|
| 251 |
-
def
|
| 252 |
self,
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
| 260 |
|
| 261 |
-
#
|
| 262 |
-
|
| 263 |
-
|
|
|
| 264 |
|
| 265 |
-
|
| 266 |
-
self.optimizer = AdamW(
|
| 267 |
-
self.model.parameters(),
|
| 268 |
-
lr=config.learning_rate,
|
| 269 |
-
weight_decay=config.weight_decay
|
| 270 |
-
)
|
| 271 |
|
| 272 |
-
#
|
| 273 |
-
|
|
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
"
|
| 278 |
-
"eval_loss": [],
|
| 279 |
-
"eval_metrics": []
|
| 280 |
}
|
| 281 |
-
|
| 282 |
-
# Early stopping
|
| 283 |
-
self.best_metric = float('-inf') if config.greater_is_better else float('inf')
|
| 284 |
-
self.patience_counter = 0
|
| 285 |
-
|
| 286 |
-
# Configurar logging
|
| 287 |
-
if config.use_wandb:
|
| 288 |
-
wandb.init(
|
| 289 |
-
project=config.wandb_project,
|
| 290 |
-
name=config.experiment_name,
|
| 291 |
-
config=asdict(config)
|
| 292 |
-
)
|
| 293 |
|
| 294 |
-
def
|
| 295 |
self,
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
if
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
|
|
| 319 |
)
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
-
#
|
| 322 |
-
|
| 323 |
-
self.scheduler = get_linear_schedule_with_warmup(
|
| 324 |
-
self.optimizer,
|
| 325 |
-
num_warmup_steps=self.config.warmup_steps,
|
| 326 |
-
num_training_steps=total_steps
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
# Loop de treinamento
|
| 330 |
-
global_step = 0
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
self.
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
break
|
| 357 |
|
| 358 |
-
#
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
# Avaliação final
|
| 365 |
-
if test_dataset:
|
| 366 |
-
test_loader = DataLoader(
|
| 367 |
-
test_dataset,
|
| 368 |
-
batch_size=self.config.batch_size,
|
| 369 |
-
shuffle=False,
|
| 370 |
-
num_workers=4
|
| 371 |
)
|
| 372 |
|
| 373 |
-
logger.info("
|
| 374 |
-
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
# Finalizar treinamento
|
| 381 |
-
self._finalize_training()
|
| 382 |
|
| 383 |
-
def
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
|
|
| 404 |
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
|
| 408 |
-
self.scaler.step(self.optimizer)
|
| 409 |
-
self.scaler.update()
|
| 410 |
-
self.scheduler.step()
|
| 411 |
-
self.optimizer.zero_grad()
|
| 412 |
-
else:
|
| 413 |
-
loss.backward()
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
total_loss += loss.item()
|
| 422 |
-
|
| 423 |
-
# Logging
|
| 424 |
-
if step % self.config.logging_steps == 0:
|
| 425 |
-
avg_loss = total_loss / (step + 1)
|
| 426 |
-
progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
|
| 427 |
|
| 428 |
-
if
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
"train/epoch": epoch,
|
| 433 |
-
"train/step": global_step + step
|
| 434 |
-
})
|
| 435 |
-
|
| 436 |
-
return total_loss / len(train_loader)
|
| 437 |
-
|
| 438 |
-
def _compute_multi_task_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 439 |
-
"""Computar loss multi-tarefa"""
|
| 440 |
-
|
| 441 |
-
total_loss = 0.0
|
| 442 |
-
loss_weights = {
|
| 443 |
-
"anomaly": 1.0,
|
| 444 |
-
"financial": 0.8,
|
| 445 |
-
"legal": 0.6
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
# Loss de detecção de anomalias
|
| 449 |
-
if "anomaly_labels" in batch:
|
| 450 |
-
anomaly_outputs = self.model.detect_anomalies(
|
| 451 |
-
input_ids=batch["input_ids"],
|
| 452 |
-
attention_mask=batch["attention_mask"],
|
| 453 |
-
entity_types=batch.get("entity_types"),
|
| 454 |
-
corruption_indicators=batch.get("corruption_indicators")
|
| 455 |
-
)
|
| 456 |
-
|
| 457 |
-
# Extrair logits dos resultados
|
| 458 |
-
anomaly_logits = []
|
| 459 |
-
for pred in anomaly_outputs["predictions"]:
|
| 460 |
-
probs = [
|
| 461 |
-
pred["probabilities"]["normal"],
|
| 462 |
-
pred["probabilities"]["suspicious"],
|
| 463 |
-
pred["probabilities"]["anomalous"]
|
| 464 |
-
]
|
| 465 |
-
anomaly_logits.append(probs)
|
| 466 |
-
|
| 467 |
-
anomaly_logits = torch.tensor(anomaly_logits, device=self.device)
|
| 468 |
-
anomaly_loss = nn.CrossEntropyLoss()(anomaly_logits, batch["anomaly_labels"])
|
| 469 |
-
total_loss += loss_weights["anomaly"] * anomaly_loss
|
| 470 |
-
|
| 471 |
-
# Loss de análise financeira
|
| 472 |
-
if "financial_risk_labels" in batch:
|
| 473 |
-
financial_outputs = self.model.analyze_financial_risk(
|
| 474 |
-
input_ids=batch["input_ids"],
|
| 475 |
-
attention_mask=batch["attention_mask"]
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
# Extrair logits dos resultados
|
| 479 |
-
risk_logits = []
|
| 480 |
-
for pred in financial_outputs["predictions"]:
|
| 481 |
-
probs = list(pred["risk_probabilities"].values())
|
| 482 |
-
risk_logits.append(probs)
|
| 483 |
-
|
| 484 |
-
risk_logits = torch.tensor(risk_logits, device=self.device)
|
| 485 |
-
financial_loss = nn.CrossEntropyLoss()(risk_logits, batch["financial_risk_labels"])
|
| 486 |
-
total_loss += loss_weights["financial"] * financial_loss
|
| 487 |
-
|
| 488 |
-
# Loss de conformidade legal
|
| 489 |
-
if "legal_compliance_labels" in batch:
|
| 490 |
-
legal_outputs = self.model.check_legal_compliance(
|
| 491 |
-
input_ids=batch["input_ids"],
|
| 492 |
-
attention_mask=batch["attention_mask"]
|
| 493 |
-
)
|
| 494 |
-
|
| 495 |
-
# Extrair logits dos resultados
|
| 496 |
-
compliance_logits = []
|
| 497 |
-
for pred in legal_outputs["predictions"]:
|
| 498 |
-
probs = [
|
| 499 |
-
pred["legal_analysis"]["non_compliant_prob"],
|
| 500 |
-
pred["legal_analysis"]["compliant_prob"]
|
| 501 |
-
]
|
| 502 |
-
compliance_logits.append(probs)
|
| 503 |
-
|
| 504 |
-
compliance_logits = torch.tensor(compliance_logits, device=self.device)
|
| 505 |
-
legal_loss = nn.CrossEntropyLoss()(compliance_logits, batch["legal_compliance_labels"])
|
| 506 |
-
total_loss += loss_weights["legal"] * legal_loss
|
| 507 |
-
|
| 508 |
-
return total_loss
|
| 509 |
-
|
| 510 |
-
def _evaluate(self, eval_loader: DataLoader, epoch: int, is_test: bool = False) -> Dict[str, float]:
|
| 511 |
-
"""Avaliar modelo"""
|
| 512 |
-
|
| 513 |
-
self.model.eval()
|
| 514 |
-
total_loss = 0.0
|
| 515 |
-
|
| 516 |
-
# Coletar predições e labels
|
| 517 |
-
all_predictions = {
|
| 518 |
-
"anomaly": {"preds": [], "labels": []},
|
| 519 |
-
"financial": {"preds": [], "labels": []},
|
| 520 |
-
"legal": {"preds": [], "labels": []}
|
| 521 |
-
}
|
| 522 |
-
|
| 523 |
-
with torch.no_grad():
|
| 524 |
-
for batch in tqdm(eval_loader, desc="Avaliação"):
|
| 525 |
-
batch = {k: v.to(self.device) for k, v in batch.items()}
|
| 526 |
|
| 527 |
-
#
|
| 528 |
-
|
| 529 |
-
|
|
|
| 530 |
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
wandb.log(log_metrics)
|
| 554 |
-
|
| 555 |
-
return metrics
|
| 556 |
-
|
| 557 |
-
def _collect_predictions(self, batch: Dict[str, torch.Tensor], all_predictions: Dict):
|
| 558 |
-
"""Coletar predições para avaliação"""
|
| 559 |
-
|
| 560 |
-
# Anomaly detection
|
| 561 |
-
if "anomaly_labels" in batch:
|
| 562 |
-
anomaly_outputs = self.model.detect_anomalies(
|
| 563 |
-
input_ids=batch["input_ids"],
|
| 564 |
-
attention_mask=batch["attention_mask"]
|
| 565 |
-
)
|
| 566 |
-
|
| 567 |
-
for i, pred in enumerate(anomaly_outputs["predictions"]):
|
| 568 |
-
anomaly_type_map = {"Normal": 0, "Suspeito": 1, "Anômalo": 2}
|
| 569 |
-
pred_label = anomaly_type_map[pred["anomaly_type"]]
|
| 570 |
-
all_predictions["anomaly"]["preds"].append(pred_label)
|
| 571 |
-
all_predictions["anomaly"]["labels"].append(batch["anomaly_labels"][i].item())
|
| 572 |
-
|
| 573 |
-
# Financial analysis
|
| 574 |
-
if "financial_risk_labels" in batch:
|
| 575 |
-
financial_outputs = self.model.analyze_financial_risk(
|
| 576 |
-
input_ids=batch["input_ids"],
|
| 577 |
-
attention_mask=batch["attention_mask"]
|
| 578 |
-
)
|
| 579 |
-
|
| 580 |
-
for i, pred in enumerate(financial_outputs["predictions"]):
|
| 581 |
-
risk_level_map = {"Muito Baixo": 0, "Baixo": 1, "Médio": 2, "Alto": 3, "Muito Alto": 4}
|
| 582 |
-
pred_label = risk_level_map[pred["risk_level"]]
|
| 583 |
-
all_predictions["financial"]["preds"].append(pred_label)
|
| 584 |
-
all_predictions["financial"]["labels"].append(batch["financial_risk_labels"][i].item())
|
| 585 |
-
|
| 586 |
-
# Legal compliance
|
| 587 |
-
if "legal_compliance_labels" in batch:
|
| 588 |
-
legal_outputs = self.model.check_legal_compliance(
|
| 589 |
-
input_ids=batch["input_ids"],
|
| 590 |
-
attention_mask=batch["attention_mask"]
|
| 591 |
-
)
|
| 592 |
-
|
| 593 |
-
for i, pred in enumerate(legal_outputs["predictions"]):
|
| 594 |
-
pred_label = 1 if pred["is_compliant"] else 0
|
| 595 |
-
all_predictions["legal"]["preds"].append(pred_label)
|
| 596 |
-
all_predictions["legal"]["labels"].append(batch["legal_compliance_labels"][i].item())
|
| 597 |
-
|
| 598 |
-
def _compute_task_metrics(self, predictions: List, labels: List, task_name: str) -> Dict[str, float]:
|
| 599 |
-
"""Computar métricas para uma tarefa específica"""
|
| 600 |
-
|
| 601 |
-
accuracy = accuracy_score(labels, predictions)
|
| 602 |
-
precision, recall, f1, _ = precision_recall_fscore_support(
|
| 603 |
-
labels, predictions, average='weighted'
|
| 604 |
-
)
|
| 605 |
-
|
| 606 |
-
metrics = {
|
| 607 |
-
f"eval_{task_name}_accuracy": accuracy,
|
| 608 |
-
f"eval_{task_name}_precision": precision,
|
| 609 |
-
f"eval_{task_name}_recall": recall,
|
| 610 |
-
f"eval_{task_name}_f1": f1
|
| 611 |
-
}
|
| 612 |
-
|
| 613 |
-
# Métrica composta para early stopping
|
| 614 |
-
if task_name == "anomaly": # Usar anomaly como principal
|
| 615 |
-
metrics["eval_f1"] = f1
|
| 616 |
-
|
| 617 |
-
return metrics
|
| 618 |
-
|
| 619 |
-
def _is_better_metric(self, current_metric: float) -> bool:
|
| 620 |
-
"""Verificar se métrica atual é melhor"""
|
| 621 |
-
if self.config.greater_is_better:
|
| 622 |
-
return current_metric > self.best_metric
|
| 623 |
-
else:
|
| 624 |
-
return current_metric < self.best_metric
|
| 625 |
-
|
| 626 |
-
def _save_checkpoint(self, epoch: int, is_best: bool = False):
|
| 627 |
-
"""Salvar checkpoint do modelo"""
|
| 628 |
-
|
| 629 |
-
output_dir = Path(self.config.output_dir)
|
| 630 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 631 |
-
|
| 632 |
-
if is_best:
|
| 633 |
-
save_path = output_dir / "best_model"
|
| 634 |
-
else:
|
| 635 |
-
save_path = output_dir / f"checkpoint-epoch-{epoch}"
|
| 636 |
-
|
| 637 |
-
# Salvar modelo
|
| 638 |
-
self.model.save_model(str(save_path))
|
| 639 |
-
|
| 640 |
-
# Salvar estado do treinamento
|
| 641 |
-
training_state = {
|
| 642 |
-
"epoch": epoch,
|
| 643 |
-
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 644 |
-
"scheduler_state_dict": self.scheduler.state_dict(),
|
| 645 |
-
"best_metric": self.best_metric,
|
| 646 |
-
"training_history": self.training_history
|
| 647 |
-
}
|
| 648 |
-
|
| 649 |
-
torch.save(training_state, save_path / "training_state.pt")
|
| 650 |
-
|
| 651 |
-
logger.info(f"✅ Checkpoint salvo em {save_path}")
|
| 652 |
-
|
| 653 |
-
def _finalize_training(self):
|
| 654 |
-
"""Finalizar treinamento"""
|
| 655 |
-
|
| 656 |
-
# Salvar histórico de treinamento
|
| 657 |
-
output_dir = Path(self.config.output_dir)
|
| 658 |
-
|
| 659 |
-
with open(output_dir / "training_history.json", "w") as f:
|
| 660 |
-
json_utils.dump(self.training_history, f, indent=2)
|
| 661 |
-
|
| 662 |
-
# Plotar curvas de treinamento
|
| 663 |
-
self._plot_training_curves()
|
| 664 |
-
|
| 665 |
-
if self.config.use_wandb:
|
| 666 |
-
wandb.finish()
|
| 667 |
-
|
| 668 |
-
logger.info("🎉 Treinamento finalizado com sucesso!")
|
| 669 |
-
|
| 670 |
-
def _plot_training_curves(self):
|
| 671 |
-
"""Plotar curvas de treinamento"""
|
| 672 |
-
|
| 673 |
-
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
| 674 |
-
|
| 675 |
-
# Loss de treinamento
|
| 676 |
-
epochs = range(1, len(self.training_history["train_loss"]) + 1)
|
| 677 |
-
axes[0, 0].plot(epochs, self.training_history["train_loss"])
|
| 678 |
-
axes[0, 0].set_title("Loss de Treinamento")
|
| 679 |
-
axes[0, 0].set_xlabel("Época")
|
| 680 |
-
axes[0, 0].set_ylabel("Loss")
|
| 681 |
-
|
| 682 |
-
# Métricas de avaliação
|
| 683 |
-
if self.training_history["eval_metrics"]:
|
| 684 |
-
eval_epochs = range(1, len(self.training_history["eval_metrics"]) + 1)
|
| 685 |
-
|
| 686 |
-
# F1 Score
|
| 687 |
-
f1_scores = [m.get("eval_f1", 0) for m in self.training_history["eval_metrics"]]
|
| 688 |
-
axes[0, 1].plot(eval_epochs, f1_scores, 'g-')
|
| 689 |
-
axes[0, 1].set_title("F1 Score")
|
| 690 |
-
axes[0, 1].set_xlabel("Época")
|
| 691 |
-
axes[0, 1].set_ylabel("F1")
|
| 692 |
-
|
| 693 |
-
# Accuracy
|
| 694 |
-
accuracy_scores = [m.get("eval_anomaly_accuracy", 0) for m in self.training_history["eval_metrics"]]
|
| 695 |
-
axes[1, 0].plot(eval_epochs, accuracy_scores, 'b-')
|
| 696 |
-
axes[1, 0].set_title("Accuracy")
|
| 697 |
-
axes[1, 0].set_xlabel("Época")
|
| 698 |
-
axes[1, 0].set_ylabel("Accuracy")
|
| 699 |
-
|
| 700 |
-
# Loss de avaliação
|
| 701 |
-
eval_losses = [m.get("eval_loss", 0) for m in self.training_history["eval_metrics"]]
|
| 702 |
-
axes[1, 1].plot(eval_epochs, eval_losses, 'r-')
|
| 703 |
-
axes[1, 1].set_title("Loss de Avaliação")
|
| 704 |
-
axes[1, 1].set_xlabel("Época")
|
| 705 |
-
axes[1, 1].set_ylabel("Loss")
|
| 706 |
-
|
| 707 |
-
plt.tight_layout()
|
| 708 |
-
|
| 709 |
-
# Salvar plot
|
| 710 |
-
output_dir = Path(self.config.output_dir)
|
| 711 |
-
plt.savefig(output_dir / "training_curves.png", dpi=300, bbox_inches='tight')
|
| 712 |
-
plt.close()
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
def create_training_pipeline(
|
| 716 |
-
data_path: str,
|
| 717 |
-
config: Optional[TrainingConfig] = None
|
| 718 |
-
) -> Tuple[CidadaoAIForTransparency, CidadaoTrainer]:
|
| 719 |
-
"""
|
| 720 |
-
Criar pipeline de treinamento completo
|
| 721 |
-
|
| 722 |
-
Args:
|
| 723 |
-
data_path: Caminho para dados de treinamento
|
| 724 |
-
config: Configuração de treinamento
|
| 725 |
-
|
| 726 |
-
Returns:
|
| 727 |
-
Tuple com modelo e trainer
|
| 728 |
-
"""
|
| 729 |
-
|
| 730 |
-
if config is None:
|
| 731 |
-
config = TrainingConfig()
|
| 732 |
-
|
| 733 |
-
logger.info("🏗️ Criando pipeline de treinamento Cidadão.AI")
|
| 734 |
-
|
| 735 |
-
# Criar modelo
|
| 736 |
-
model = create_cidadao_model(
|
| 737 |
-
specialized_tasks=config.specialized_tasks,
|
| 738 |
-
model_size=config.model_size
|
| 739 |
-
)
|
| 740 |
-
|
| 741 |
-
# Criar tokenizer
|
| 742 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
| 743 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 744 |
-
|
| 745 |
-
# Redimensionar embeddings se necessário
|
| 746 |
-
model.model.model.resize_token_embeddings(len(tokenizer))
|
| 747 |
-
|
| 748 |
-
# Criar trainer
|
| 749 |
-
trainer = CidadaoTrainer(model, tokenizer, config)
|
| 750 |
-
|
| 751 |
-
logger.info(f"✅ Pipeline criado - Modelo: {config.model_size}, Tarefas: {config.specialized_tasks}")
|
| 752 |
-
|
| 753 |
-
return model, trainer
|
| 754 |
|
| 755 |
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
Preparar dados de transparência para treinamento
|
| 759 |
-
|
| 760 |
-
Esta função seria expandida para processar dados reais do Portal da Transparência
|
| 761 |
-
"""
|
| 762 |
-
|
| 763 |
-
logger.info("📊 Preparando dados de transparência")
|
| 764 |
-
|
| 765 |
-
output_dir = Path(output_dir)
|
| 766 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 767 |
-
|
| 768 |
-
# Aqui você implementaria:
|
| 769 |
-
# 1. Conexão com Portal da Transparência API
|
| 770 |
-
# 2. Extração e limpeza de dados
|
| 771 |
-
# 3. Anotação de anomalias (semi-supervisionado)
|
| 772 |
-
# 4. Balanceamento de classes
|
| 773 |
-
# 5. Divisão train/val/test
|
| 774 |
-
|
| 775 |
-
# Por enquanto, criar dados sintéticos
|
| 776 |
-
logger.info("⚠️ Usando dados sintéticos para demonstração")
|
| 777 |
-
|
| 778 |
-
# Implementação completa seria conectada aos dados reais
|
| 779 |
-
sample_data = {
|
| 780 |
-
"train": output_dir / "train.json",
|
| 781 |
-
"val": output_dir / "val.json",
|
| 782 |
-
"test": output_dir / "test.json"
|
| 783 |
-
}
|
| 784 |
-
|
| 785 |
-
return sample_data
|
| 786 |
|
| 787 |
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
# Configurar logging
|
| 792 |
-
logging.basicConfig(level=logging.INFO)
|
| 793 |
-
|
| 794 |
-
# Configuração de treinamento
|
| 795 |
-
config = TrainingConfig(
|
| 796 |
-
experiment_name="cidadao-gpt-transparency-v1",
|
| 797 |
-
num_epochs=5,
|
| 798 |
-
batch_size=4, # Reduzido para teste
|
| 799 |
-
learning_rate=2e-5,
|
| 800 |
-
use_wandb=False, # Desabilitar para teste
|
| 801 |
-
output_dir="./models/cidadao-gpt-test"
|
| 802 |
-
)
|
| 803 |
-
|
| 804 |
-
# Criar pipeline
|
| 805 |
-
model, trainer = create_training_pipeline(
|
| 806 |
-
data_path="./data/transparency_data.json",
|
| 807 |
-
config=config
|
| 808 |
-
)
|
| 809 |
-
|
| 810 |
-
print("🤖 Cidadão.AI Training Pipeline criado com sucesso!")
|
| 811 |
-
print(f"📊 Modelo: {config.model_size}")
|
| 812 |
-
print(f"🎯 Tarefas especializadas: {config.specialized_tasks}")
|
| 813 |
-
print(f"💾 Diretório de saída: {config.output_dir}")
|
|
|
|
| 1 |
"""
|
| 2 |
+
ML Training Pipeline for Cidadão.AI
|
| 3 |
|
| 4 |
+
This module provides a comprehensive training pipeline for ML models
|
| 5 |
+
used in anomaly detection, fraud detection, and pattern recognition.
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
import asyncio
|
| 9 |
+
import json
|
| 10 |
import os
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
+
import pickle
|
| 15 |
+
import joblib
|
| 16 |
+
import numpy as np
|
| 17 |
+
from sklearn.model_selection import train_test_split, cross_val_score
|
| 18 |
+
from sklearn.metrics import (
|
| 19 |
+
accuracy_score, precision_score, recall_score, f1_score,
|
| 20 |
+
roc_auc_score, confusion_matrix, classification_report
|
| 21 |
+
)
|
| 22 |
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
| 23 |
+
from sklearn.ensemble import IsolationForest, RandomForestClassifier
|
| 24 |
+
from sklearn.svm import OneClassSVM
|
| 25 |
+
from sklearn.neighbors import LocalOutlierFactor
|
| 26 |
+
import mlflow
|
| 27 |
+
import mlflow.sklearn
|
| 28 |
+
from mlflow.tracking import MlflowClient
|
| 29 |
|
| 30 |
+
from src.core import get_logger, settings
|
| 31 |
+
from src.core.exceptions import CidadaoAIError
|
| 32 |
+
from src.infrastructure.cache.redis_client import get_redis_client
|
| 33 |
+
from src.models.ml_models import AnomalyDetectorModel
|
| 34 |
|
| 35 |
|
| 36 |
+
logger = get_logger(__name__)
|
|
|
| 37 |
|
| 38 |
|
| 39 |
+
class MLTrainingPipeline:
|
| 40 |
+
"""
|
| 41 |
+
Comprehensive ML training pipeline with versioning and tracking.
|
| 42 |
+
|
| 43 |
+
Features:
|
| 44 |
+
- Multiple algorithm support
|
| 45 |
+
- Automatic hyperparameter tuning
|
| 46 |
+
- Model versioning with MLflow
|
| 47 |
+
- Performance tracking
|
| 48 |
+
- A/B testing support
|
| 49 |
+
"""
|
| 50 |
|
| 51 |
+
def __init__(self, experiment_name: str = "cidadao_ai_models"):
|
| 52 |
+
"""Initialize the training pipeline."""
|
| 53 |
+
self.experiment_name = experiment_name
|
| 54 |
+
self.mlflow_client = None
|
| 55 |
+
self.models_dir = Path(settings.get("ML_MODELS_DIR", "./models"))
|
| 56 |
+
self.models_dir.mkdir(exist_ok=True)
|
| 57 |
+
|
| 58 |
+
# Supported algorithms
|
| 59 |
+
self.algorithms = {
|
| 60 |
+
"isolation_forest": IsolationForest,
|
| 61 |
+
"one_class_svm": OneClassSVM,
|
| 62 |
+
"random_forest": RandomForestClassifier,
|
| 63 |
+
"local_outlier_factor": LocalOutlierFactor
|
| 64 |
+
}
|
|
|
| 65 |
|
| 66 |
+
# Model registry
|
| 67 |
+
self.model_registry = {}
|
| 68 |
+
self._initialize_mlflow()
|
| 69 |
+
|
| 70 |
+
def _initialize_mlflow(self):
|
| 71 |
+
"""Initialize MLflow tracking."""
|
| 72 |
+
try:
|
| 73 |
+
mlflow.set_tracking_uri(settings.get("MLFLOW_TRACKING_URI", "file:./mlruns"))
|
| 74 |
+
mlflow.set_experiment(self.experiment_name)
|
| 75 |
+
self.mlflow_client = MlflowClient()
|
| 76 |
+
logger.info(f"MLflow initialized with experiment: {self.experiment_name}")
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.warning(f"MLflow initialization failed: {e}. Using local tracking.")
|
| 79 |
+
|
| 80 |
+
async def train_model(
|
| 81 |
+
self,
|
| 82 |
+
model_type: str,
|
| 83 |
+
algorithm: str,
|
| 84 |
+
X_train: np.ndarray,
|
| 85 |
+
y_train: Optional[np.ndarray] = None,
|
| 86 |
+
hyperparameters: Optional[Dict[str, Any]] = None,
|
| 87 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 88 |
+
) -> Dict[str, Any]:
|
| 89 |
+
"""
|
| 90 |
+
Train a model with the specified algorithm.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
model_type: Type of model (anomaly, fraud, pattern)
|
| 94 |
+
algorithm: Algorithm to use
|
| 95 |
+
X_train: Training features
|
| 96 |
+
y_train: Training labels (optional for unsupervised)
|
| 97 |
+
hyperparameters: Model hyperparameters
|
| 98 |
+
metadata: Additional metadata
|
| 99 |
|
| 100 |
+
Returns:
|
| 101 |
+
Training results with model info
|
| 102 |
+
"""
|
| 103 |
+
try:
|
| 104 |
+
logger.info(f"Starting training for {model_type} with {algorithm}")
|
| 105 |
+
|
| 106 |
+
# Start MLflow run
|
| 107 |
+
with mlflow.start_run(run_name=f"{model_type}_{algorithm}_{datetime.now().isoformat()}"):
|
| 108 |
+
# Log parameters
|
| 109 |
+
mlflow.log_param("model_type", model_type)
|
| 110 |
+
mlflow.log_param("algorithm", algorithm)
|
| 111 |
+
mlflow.log_param("n_samples", X_train.shape[0])
|
| 112 |
+
mlflow.log_param("n_features", X_train.shape[1])
|
|
|
| 113 |
|
| 114 |
+
if hyperparameters:
|
| 115 |
+
for key, value in hyperparameters.items():
|
| 116 |
+
mlflow.log_param(f"param_{key}", value)
|
| 117 |
|
| 118 |
+
# Create and train model
|
| 119 |
+
model = await self._create_model(algorithm, hyperparameters)
|
| 120 |
+
|
| 121 |
+
# Train based on supervised/unsupervised
|
| 122 |
+
if y_train is not None:
|
| 123 |
+
# Supervised training
|
| 124 |
+
X_tr, X_val, y_tr, y_val = train_test_split(
|
| 125 |
+
X_train, y_train, test_size=0.2, random_state=42
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
model.fit(X_tr, y_tr)
|
| 129 |
+
|
| 130 |
+
# Evaluate
|
| 131 |
+
y_pred = model.predict(X_val)
|
| 132 |
+
metrics = self._calculate_metrics(y_val, y_pred)
|
| 133 |
+
|
| 134 |
+
# Cross-validation
|
| 135 |
+
cv_scores = cross_val_score(model, X_train, y_train, cv=5)
|
| 136 |
+
metrics["cv_score_mean"] = cv_scores.mean()
|
| 137 |
+
metrics["cv_score_std"] = cv_scores.std()
|
| 138 |
+
|
| 139 |
+
else:
|
| 140 |
+
# Unsupervised training
|
| 141 |
+
model.fit(X_train)
|
| 142 |
+
|
| 143 |
+
# Evaluate with anomaly scores
|
| 144 |
+
if hasattr(model, 'score_samples'):
|
| 145 |
+
anomaly_scores = model.score_samples(X_train)
|
| 146 |
+
metrics = {
|
| 147 |
+
"anomaly_score_mean": float(np.mean(anomaly_scores)),
|
| 148 |
+
"anomaly_score_std": float(np.std(anomaly_scores)),
|
| 149 |
+
"anomaly_score_min": float(np.min(anomaly_scores)),
|
| 150 |
+
"anomaly_score_max": float(np.max(anomaly_scores))
|
| 151 |
+
}
|
| 152 |
+
else:
|
| 153 |
+
metrics = {"training_complete": True}
|
| 154 |
+
|
| 155 |
+
# Log metrics
|
| 156 |
+
for metric_name, metric_value in metrics.items():
|
| 157 |
+
mlflow.log_metric(metric_name, metric_value)
|
| 158 |
+
|
| 159 |
+
# Save model
|
| 160 |
+
model_path = await self._save_model(
|
| 161 |
+
model, model_type, algorithm, metrics, metadata
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Log model to MLflow
|
| 165 |
+
mlflow.sklearn.log_model(
|
| 166 |
+
model,
|
| 167 |
+
f"{model_type}_{algorithm}",
|
| 168 |
+
registered_model_name=f"{model_type}_{algorithm}_model"
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Create model version
|
| 172 |
+
version = await self._create_model_version(
|
| 173 |
+
model_type, algorithm, model_path, metrics
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return {
|
| 177 |
+
"success": True,
|
| 178 |
+
"model_id": version["model_id"],
|
| 179 |
+
"version": version["version"],
|
| 180 |
+
"metrics": metrics,
|
| 181 |
+
"model_path": model_path,
|
| 182 |
+
"run_id": mlflow.active_run().info.run_id
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.error(f"Training failed: {str(e)}")
|
| 187 |
+
return {
|
| 188 |
+
"success": False,
|
| 189 |
+
"error": str(e),
|
| 190 |
+
"model_id": None
|
| 191 |
+
}
|
| 192 |
|
| 193 |
+
async def _create_model(
|
| 194 |
+
self,
|
| 195 |
+
algorithm: str,
|
| 196 |
+
hyperparameters: Optional[Dict[str, Any]] = None
|
| 197 |
+
) -> Any:
|
| 198 |
+
"""Create a model instance with hyperparameters."""
|
| 199 |
+
if algorithm not in self.algorithms:
|
| 200 |
+
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
| 201 |
+
|
| 202 |
+
model_class = self.algorithms[algorithm]
|
| 203 |
+
|
| 204 |
+
# Default hyperparameters
|
| 205 |
+
default_params = {
|
| 206 |
+
"isolation_forest": {
|
| 207 |
+
"contamination": 0.1,
|
| 208 |
+
"random_state": 42,
|
| 209 |
+
"n_estimators": 100
|
| 210 |
+
},
|
| 211 |
+
"one_class_svm": {
|
| 212 |
+
"gamma": 0.001,
|
| 213 |
+
"nu": 0.05,
|
| 214 |
+
"kernel": "rbf"
|
| 215 |
+
},
|
| 216 |
+
"random_forest": {
|
| 217 |
+
"n_estimators": 100,
|
| 218 |
+
"random_state": 42,
|
| 219 |
+
"max_depth": 10
|
| 220 |
+
},
|
| 221 |
+
"local_outlier_factor": {
|
| 222 |
+
"contamination": 0.1,
|
| 223 |
+
"n_neighbors": 20
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
|
| 227 |
+
# Merge with provided hyperparameters
|
| 228 |
+
params = default_params.get(algorithm, {})
|
| 229 |
+
if hyperparameters:
|
| 230 |
+
params.update(hyperparameters)
|
|
|
| 231 |
|
| 232 |
+
return model_class(**params)
|
| 233 |
|
| 234 |
+
def _calculate_metrics(
|
| 235 |
+
self,
|
| 236 |
+
y_true: np.ndarray,
|
| 237 |
+
y_pred: np.ndarray,
|
| 238 |
+
y_proba: Optional[np.ndarray] = None
|
| 239 |
+
) -> Dict[str, float]:
|
| 240 |
+
"""Calculate comprehensive metrics for model evaluation."""
|
| 241 |
+
metrics = {
|
| 242 |
+
"accuracy": float(accuracy_score(y_true, y_pred)),
|
| 243 |
+
"precision": float(precision_score(y_true, y_pred, average='weighted')),
|
| 244 |
+
"recall": float(recall_score(y_true, y_pred, average='weighted')),
|
| 245 |
+
"f1_score": float(f1_score(y_true, y_pred, average='weighted'))
|
|
|
| 246 |
}
|
| 247 |
|
| 248 |
+
# Add ROC-AUC if probabilities available
|
| 249 |
+
if y_proba is not None and len(np.unique(y_true)) == 2:
|
| 250 |
+
metrics["roc_auc"] = float(roc_auc_score(y_true, y_proba[:, 1]))
|
|
|
| 251 |
|
| 252 |
+
return metrics
|
| 253 |
+
|
| 254 |
+
async def _save_model(
|
| 255 |
+
self,
|
| 256 |
+
model: Any,
|
| 257 |
+
model_type: str,
|
| 258 |
+
algorithm: str,
|
| 259 |
+
metrics: Dict[str, Any],
|
| 260 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 261 |
+
) -> str:
|
| 262 |
+
"""Save trained model to disk."""
|
| 263 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 264 |
+
model_filename = f"{model_type}_{algorithm}_{timestamp}.pkl"
|
| 265 |
+
model_path = self.models_dir / model_filename
|
| 266 |
+
|
| 267 |
+
# Create model package
|
| 268 |
+
model_package = {
|
| 269 |
+
"model": model,
|
| 270 |
+
"model_type": model_type,
|
| 271 |
+
"algorithm": algorithm,
|
| 272 |
+
"metrics": metrics,
|
| 273 |
+
"metadata": metadata or {},
|
| 274 |
+
"created_at": datetime.now().isoformat(),
|
| 275 |
+
"version": timestamp
|
| 276 |
}
|
| 277 |
|
| 278 |
+
# Save with joblib for better compression
|
| 279 |
+
joblib.dump(model_package, model_path)
|
| 280 |
+
logger.info(f"Model saved to: {model_path}")
|
|
|
| 281 |
|
| 282 |
+
return str(model_path)
|
|
|
| 283 |
|
| 284 |
+
async def _create_model_version(
|
| 285 |
self,
|
| 286 |
+
model_type: str,
|
| 287 |
+
algorithm: str,
|
| 288 |
+
model_path: str,
|
| 289 |
+
metrics: Dict[str, Any]
|
| 290 |
+
) -> Dict[str, Any]:
|
| 291 |
+
"""Create a versioned model entry in the registry."""
|
| 292 |
+
model_id = f"{model_type}_{algorithm}"
|
| 293 |
+
|
| 294 |
+
# Get or create model entry
|
| 295 |
+
if model_id not in self.model_registry:
|
| 296 |
+
self.model_registry[model_id] = {
|
| 297 |
+
"versions": [],
|
| 298 |
+
"current_version": None,
|
| 299 |
+
"created_at": datetime.now().isoformat()
|
| 300 |
+
}
|
| 301 |
|
| 302 |
+
# Add new version
|
| 303 |
+
version = {
|
| 304 |
+
"version": len(self.model_registry[model_id]["versions"]) + 1,
|
| 305 |
+
"path": model_path,
|
| 306 |
+
"metrics": metrics,
|
| 307 |
+
"created_at": datetime.now().isoformat(),
|
| 308 |
+
"status": "staging" # staging, production, archived
|
| 309 |
+
}
|
| 310 |
|
| 311 |
+
self.model_registry[model_id]["versions"].append(version)
|
|
|
| 312 |
|
| 313 |
+
# Save registry to Redis
|
| 314 |
+
redis_client = await get_redis_client()
|
| 315 |
+
await redis_client.set(
|
| 316 |
+
f"ml_model_registry:{model_id}",
|
| 317 |
+
json.dumps(self.model_registry[model_id]),
|
| 318 |
+
ex=86400 * 30 # 30 days
|
| 319 |
+
)
|
| 320 |
|
| 321 |
+
return {
|
| 322 |
+
"model_id": model_id,
|
| 323 |
+
"version": version["version"]
|
|
|
|
|
|
|
| 324 |
}
|
|
|
| 325 |
|
| 326 |
+
async def load_model(
|
| 327 |
self,
|
| 328 |
+
model_id: str,
|
| 329 |
+
version: Optional[int] = None
|
| 330 |
+
) -> Tuple[Any, Dict[str, Any]]:
|
| 331 |
+
"""
|
| 332 |
+
Load a model from the registry.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
model_id: Model identifier
|
| 336 |
+
version: Specific version (latest if None)
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
Tuple of (model, metadata)
|
| 340 |
+
"""
|
| 341 |
+
# Try to load from Redis first
|
| 342 |
+
redis_client = await get_redis_client()
|
| 343 |
+
registry_data = await redis_client.get(f"ml_model_registry:{model_id}")
|
| 344 |
+
|
| 345 |
+
if registry_data:
|
| 346 |
+
registry = json.loads(registry_data)
|
| 347 |
+
elif model_id in self.model_registry:
|
| 348 |
+
registry = self.model_registry[model_id]
|
| 349 |
+
else:
|
| 350 |
+
raise ValueError(f"Model {model_id} not found in registry")
|
| 351 |
+
|
| 352 |
+
# Get version
|
| 353 |
+
if version is None:
|
| 354 |
+
# Get latest production version or latest version
|
| 355 |
+
prod_versions = [
|
| 356 |
+
v for v in registry["versions"]
|
| 357 |
+
if v.get("status") == "production"
|
| 358 |
+
]
|
| 359 |
+
|
| 360 |
+
if prod_versions:
|
| 361 |
+
version_data = max(prod_versions, key=lambda v: v["version"])
|
| 362 |
+
else:
|
| 363 |
+
version_data = max(registry["versions"], key=lambda v: v["version"])
|
| 364 |
+
else:
|
| 365 |
+
version_data = next(
|
| 366 |
+
(v for v in registry["versions"] if v["version"] == version),
|
| 367 |
+
None
|
| 368 |
)
|
| 369 |
+
|
| 370 |
+
if not version_data:
|
| 371 |
+
raise ValueError(f"Version {version} not found for model {model_id}")
|
| 372 |
|
| 373 |
+
# Load model
|
| 374 |
+
model_package = joblib.load(version_data["path"])
|
|
|
| 375 |
|
| 376 |
+
return model_package["model"], model_package
|
| 377 |
+
|
| 378 |
+
async def promote_model(
|
| 379 |
+
self,
|
| 380 |
+
model_id: str,
|
| 381 |
+
version: int,
|
| 382 |
+
status: str = "production"
|
| 383 |
+
) -> bool:
|
| 384 |
+
"""
|
| 385 |
+
Promote a model version to production.
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
model_id: Model identifier
|
| 389 |
+
version: Version to promote
|
| 390 |
+
status: New status (production, staging, archived)
|
| 391 |
+
"""
|
| 392 |
+
try:
|
| 393 |
+
# Load registry
|
| 394 |
+
redis_client = await get_redis_client()
|
| 395 |
+
registry_data = await redis_client.get(f"ml_model_registry:{model_id}")
|
| 396 |
|
| 397 |
+
if registry_data:
|
| 398 |
+
registry = json.loads(registry_data)
|
| 399 |
+
else:
|
| 400 |
+
registry = self.model_registry.get(model_id)
|
| 401 |
|
| 402 |
+
if not registry:
|
| 403 |
+
raise ValueError(f"Model {model_id} not found")
|
| 404 |
+
|
| 405 |
+
# Update version status
|
| 406 |
+
for v in registry["versions"]:
|
| 407 |
+
if v["version"] == version:
|
| 408 |
+
# Archive current production if promoting to production
|
| 409 |
+
if status == "production":
|
| 410 |
+
for other_v in registry["versions"]:
|
| 411 |
+
if other_v.get("status") == "production":
|
| 412 |
+
other_v["status"] = "archived"
|
| 413 |
|
| 414 |
+
v["status"] = status
|
| 415 |
+
v["promoted_at"] = datetime.now().isoformat()
|
| 416 |
break
|
| 417 |
|
| 418 |
+
# Save updated registry
|
| 419 |
+
self.model_registry[model_id] = registry
|
| 420 |
+
await redis_client.set(
|
| 421 |
+
f"ml_model_registry:{model_id}",
|
| 422 |
+
json.dumps(registry),
|
| 423 |
+
ex=86400 * 30
|
|
|
| 424 |
)
|
| 425 |
|
| 426 |
+
logger.info(f"Promoted {model_id} v{version} to {status}")
|
| 427 |
+
return True
|
| 428 |
|
| 429 |
+
except Exception as e:
|
| 430 |
+
logger.error(f"Failed to promote model: {e}")
|
| 431 |
+
return False
|
|
|
| 432 |
|
| 433 |
+
async def get_model_metrics(
|
| 434 |
+
self,
|
| 435 |
+
model_id: str,
|
| 436 |
+
version: Optional[int] = None
|
| 437 |
+
) -> Dict[str, Any]:
|
| 438 |
+
"""Get metrics for a specific model version."""
|
| 439 |
+
_, metadata = await self.load_model(model_id, version)
|
| 440 |
+
return metadata.get("metrics", {})
|
| 441 |
+
|
| 442 |
+
async def compare_models(
|
| 443 |
+
self,
|
| 444 |
+
model_ids: List[Tuple[str, Optional[int]]],
|
| 445 |
+
test_data: np.ndarray,
|
| 446 |
+
test_labels: Optional[np.ndarray] = None
|
| 447 |
+
) -> Dict[str, Any]:
|
| 448 |
+
"""
|
| 449 |
+
Compare multiple models on the same test data.
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
model_ids: List of (model_id, version) tuples
|
| 453 |
+
test_data: Test features
|
| 454 |
+
test_labels: Test labels (if available)
|
| 455 |
|
| 456 |
+
Returns:
|
| 457 |
+
Comparison results
|
| 458 |
+
"""
|
| 459 |
+
results = {}
|
| 460 |
+
|
| 461 |
+
for model_id, version in model_ids:
|
| 462 |
+
try:
|
| 463 |
+
model, metadata = await self.load_model(model_id, version)
|
| 464 |
|
| 465 |
+
# Make predictions
|
| 466 |
+
predictions = model.predict(test_data)
|
|
|
| 467 |
|
| 468 |
+
result = {
|
| 469 |
+
"model_id": model_id,
|
| 470 |
+
"version": version or "latest",
|
| 471 |
+
"algorithm": metadata.get("algorithm"),
|
| 472 |
+
"training_metrics": metadata.get("metrics", {})
|
| 473 |
+
}
|
|
|
| 474 |
|
| 475 |
+
# Calculate test metrics if labels available
|
| 476 |
+
if test_labels is not None:
|
| 477 |
+
test_metrics = self._calculate_metrics(test_labels, predictions)
|
| 478 |
+
result["test_metrics"] = test_metrics
|
|
|
| 479 |
|
| 480 |
+
# Add anomaly scores for unsupervised models
|
| 481 |
+
if hasattr(model, 'score_samples'):
|
| 482 |
+
scores = model.score_samples(test_data)
|
| 483 |
+
result["anomaly_scores"] = {
|
| 484 |
+
"mean": float(np.mean(scores)),
|
| 485 |
+
"std": float(np.std(scores)),
|
| 486 |
+
"percentiles": {
|
| 487 |
+
"10": float(np.percentile(scores, 10)),
|
| 488 |
+
"50": float(np.percentile(scores, 50)),
|
| 489 |
+
"90": float(np.percentile(scores, 90))
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
|
| 493 |
+
results[f"{model_id}_v{version or 'latest'}"] = result
|
| 494 |
+
|
| 495 |
+
except Exception as e:
|
| 496 |
+
logger.error(f"Failed to evaluate {model_id}: {e}")
|
| 497 |
+
results[f"{model_id}_v{version or 'latest'}"] = {
|
| 498 |
+
"error": str(e)
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
return results
|
| 502 |
+
|
| 503 |
+
async def cleanup_old_models(self, days: int = 30) -> int:
|
| 504 |
+
"""Remove models older than specified days."""
|
| 505 |
+
count = 0
|
| 506 |
+
cutoff_date = datetime.now().timestamp() - (days * 86400)
|
| 507 |
+
|
| 508 |
+
for model_file in self.models_dir.glob("*.pkl"):
|
| 509 |
+
if model_file.stat().st_mtime < cutoff_date:
|
| 510 |
+
model_file.unlink()
|
| 511 |
+
count += 1
|
| 512 |
+
logger.info(f"Removed old model: {model_file}")
|
| 513 |
+
|
| 514 |
+
return count
|
|
|
| 515 |
|
| 516 |
|
| 517 |
+
# Global training pipeline instance
|
| 518 |
+
training_pipeline = MLTrainingPipeline()
|
|
|
| 519 |
|
| 520 |
|
| 521 |
+
async def get_training_pipeline() -> MLTrainingPipeline:
|
| 522 |
+
"""Get the global training pipeline instance."""
|
| 523 |
+
return training_pipeline
|
|
|
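A minimal usage sketch of the pipeline above, assuming the Redis and MLflow settings referenced in the module are reachable; data and names here are synthetic and purely illustrative:

import asyncio
import numpy as np

from src.ml.training_pipeline import get_training_pipeline


async def main():
    pipeline = await get_training_pipeline()

    # Unsupervised anomaly detector trained on synthetic features
    X = np.random.randn(500, 12)
    result = await pipeline.train_model(
        model_type="anomaly",
        algorithm="isolation_forest",
        X_train=X,
        hyperparameters={"contamination": 0.05},
    )
    print(result["model_id"], result["version"], result["metrics"])

    # Promote the fresh version and load it back for inference
    await pipeline.promote_model(result["model_id"], result["version"], "production")
    model, metadata = await pipeline.load_model(result["model_id"])
    print(metadata["algorithm"], model.predict(X[:5]))


if __name__ == "__main__":
    asyncio.run(main())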
tests/unit/ml/__init__.py
ADDED
|
File without changes
|
tests/unit/ml/test_training_pipeline.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for ML Training Pipeline
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import asyncio
|
| 7 |
+
import numpy as np
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
from src.ml.training_pipeline import MLTrainingPipeline, training_pipeline
|
| 13 |
+
from src.ml.ab_testing import ABTestFramework, ABTestStatus, TrafficAllocationStrategy
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestMLTrainingPipeline:
|
| 17 |
+
"""Test suite for ML training pipeline."""
|
| 18 |
+
|
| 19 |
+
@pytest.fixture
|
| 20 |
+
def pipeline(self):
|
| 21 |
+
"""Create a test pipeline instance."""
|
| 22 |
+
return MLTrainingPipeline(experiment_name="test_experiment")
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def sample_data(self):
|
| 26 |
+
"""Generate sample training data."""
|
| 27 |
+
X_train = np.random.randn(100, 10)
|
| 28 |
+
y_train = np.random.choice([0, 1], size=100)
|
| 29 |
+
return X_train, y_train
|
| 30 |
+
|
| 31 |
+
@pytest.mark.asyncio
|
| 32 |
+
async def test_pipeline_initialization(self, pipeline):
|
| 33 |
+
"""Test pipeline initialization."""
|
| 34 |
+
assert pipeline.experiment_name == "test_experiment"
|
| 35 |
+
assert pipeline.models_dir.exists()
|
| 36 |
+
assert len(pipeline.algorithms) > 0
|
| 37 |
+
assert "isolation_forest" in pipeline.algorithms
|
| 38 |
+
|
| 39 |
+
@pytest.mark.asyncio
|
| 40 |
+
async def test_train_unsupervised_model(self, pipeline, sample_data):
|
| 41 |
+
"""Test training an unsupervised model."""
|
| 42 |
+
X_train, _ = sample_data
|
| 43 |
+
|
| 44 |
+
with patch('mlflow.start_run'), \
|
| 45 |
+
patch('mlflow.log_param'), \
|
| 46 |
+
patch('mlflow.log_metric'), \
|
| 47 |
+
patch('mlflow.sklearn.log_model'):
|
| 48 |
+
|
| 49 |
+
result = await pipeline.train_model(
|
| 50 |
+
model_type="anomaly",
|
| 51 |
+
algorithm="isolation_forest",
|
| 52 |
+
X_train=X_train,
|
| 53 |
+
hyperparameters={"contamination": 0.1}
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
assert result["success"] is True
|
| 57 |
+
assert result["model_id"] == "anomaly_isolation_forest"
|
| 58 |
+
assert result["version"] == 1
|
| 59 |
+
assert "metrics" in result
|
| 60 |
+
assert "model_path" in result
|
| 61 |
+
|
| 62 |
+
@pytest.mark.asyncio
|
| 63 |
+
async def test_train_supervised_model(self, pipeline, sample_data):
|
| 64 |
+
"""Test training a supervised model."""
|
| 65 |
+
X_train, y_train = sample_data
|
| 66 |
+
|
| 67 |
+
with patch('mlflow.start_run'), \
|
| 68 |
+
patch('mlflow.log_param'), \
|
| 69 |
+
patch('mlflow.log_metric'), \
|
| 70 |
+
patch('mlflow.sklearn.log_model'):
|
| 71 |
+
|
| 72 |
+
result = await pipeline.train_model(
|
| 73 |
+
model_type="fraud",
|
| 74 |
+
algorithm="random_forest",
|
| 75 |
+
X_train=X_train,
|
| 76 |
+
y_train=y_train,
|
| 77 |
+
hyperparameters={"n_estimators": 50}
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
assert result["success"] is True
|
| 81 |
+
assert result["model_id"] == "fraud_random_forest"
|
| 82 |
+
assert "accuracy" in result["metrics"]
|
| 83 |
+
assert "precision" in result["metrics"]
|
| 84 |
+
assert "recall" in result["metrics"]
|
| 85 |
+
assert "f1_score" in result["metrics"]
|
| 86 |
+
|
| 87 |
+
@pytest.mark.asyncio
|
| 88 |
+
async def test_model_versioning(self, pipeline, sample_data):
|
| 89 |
+
"""Test model versioning system."""
|
| 90 |
+
X_train, _ = sample_data
|
| 91 |
+
|
| 92 |
+
with patch('mlflow.start_run'), \
|
| 93 |
+
patch('mlflow.log_param'), \
|
| 94 |
+
patch('mlflow.log_metric'), \
|
| 95 |
+
patch('mlflow.sklearn.log_model'), \
|
| 96 |
+
patch.object(pipeline, '_save_model') as mock_save:
|
| 97 |
+
|
| 98 |
+
# Mock save to return a path
|
| 99 |
+
mock_save.return_value = "/models/test_model.pkl"
|
| 100 |
+
|
| 101 |
+
# Train first version
|
| 102 |
+
result1 = await pipeline.train_model(
|
| 103 |
+
model_type="anomaly",
|
| 104 |
+
algorithm="isolation_forest",
|
| 105 |
+
X_train=X_train
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Train second version
|
| 109 |
+
result2 = await pipeline.train_model(
|
| 110 |
+
model_type="anomaly",
|
| 111 |
+
algorithm="isolation_forest",
|
| 112 |
+
X_train=X_train
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
assert result1["version"] == 1
|
| 116 |
+
assert result2["version"] == 2
|
| 117 |
+
assert len(pipeline.model_registry["anomaly_isolation_forest"]["versions"]) == 2
|
| 118 |
+
|
| 119 |
+
@pytest.mark.asyncio
|
| 120 |
+
async def test_load_model(self, pipeline, sample_data):
|
| 121 |
+
"""Test loading a model from registry."""
|
| 122 |
+
X_train, _ = sample_data
|
| 123 |
+
|
| 124 |
+
# Create a mock model
|
| 125 |
+
mock_model = MagicMock()
|
| 126 |
+
model_package = {
|
| 127 |
+
"model": mock_model,
|
| 128 |
+
"model_type": "anomaly",
|
| 129 |
+
"algorithm": "isolation_forest",
|
| 130 |
+
"metrics": {"score": 0.95},
|
| 131 |
+
"created_at": datetime.now().isoformat()
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
with patch('joblib.load', return_value=model_package), \
|
| 135 |
+
patch.object(pipeline, 'model_registry', {
|
| 136 |
+
"anomaly_isolation_forest": {
|
| 137 |
+
"versions": [{
|
| 138 |
+
"version": 1,
|
| 139 |
+
"path": "/models/test.pkl",
|
| 140 |
+
"status": "production"
|
| 141 |
+
}]
|
| 142 |
+
}
|
| 143 |
+
}):
|
| 144 |
+
|
| 145 |
+
model, metadata = await pipeline.load_model("anomaly_isolation_forest")
|
| 146 |
+
|
| 147 |
+
assert model == mock_model
|
| 148 |
+
assert metadata["model_type"] == "anomaly"
|
| 149 |
+
assert metadata["algorithm"] == "isolation_forest"
|
| 150 |
+
|
| 151 |
+
@pytest.mark.asyncio
|
| 152 |
+
async def test_promote_model(self, pipeline):
|
| 153 |
+
"""Test promoting a model to production."""
|
| 154 |
+
with patch('src.infrastructure.cache.redis_client.get_redis_client') as mock_redis:
|
| 155 |
+
mock_redis_client = AsyncMock()
|
| 156 |
+
mock_redis_client.get.return_value = json.dumps({
|
| 157 |
+
"versions": [
|
| 158 |
+
{"version": 1, "status": "staging"},
|
| 159 |
+
{"version": 2, "status": "staging"}
|
| 160 |
+
]
|
| 161 |
+
})
|
| 162 |
+
mock_redis_client.set = AsyncMock()
|
| 163 |
+
mock_redis.return_value = mock_redis_client
|
| 164 |
+
|
| 165 |
+
success = await pipeline.promote_model("test_model", 2, "production")
|
| 166 |
+
|
| 167 |
+
assert success is True
|
| 168 |
+
mock_redis_client.set.assert_called_once()
|
| 169 |
+
|
| 170 |
+
@pytest.mark.asyncio
|
| 171 |
+
async def test_compare_models(self, pipeline):
|
| 172 |
+
"""Test comparing multiple models."""
|
| 173 |
+
test_data = np.random.randn(50, 10)
|
| 174 |
+
test_labels = np.random.choice([0, 1], size=50)
|
| 175 |
+
|
| 176 |
+
# Mock models
|
| 177 |
+
mock_model1 = MagicMock()
|
| 178 |
+
mock_model1.predict.return_value = np.ones(50)
|
| 179 |
+
mock_model1.score_samples = MagicMock(return_value=np.random.randn(50))
|
| 180 |
+
|
| 181 |
+
mock_model2 = MagicMock()
|
| 182 |
+
mock_model2.predict.return_value = np.zeros(50)
|
| 183 |
+
|
| 184 |
+
with patch.object(pipeline, 'load_model') as mock_load:
|
| 185 |
+
mock_load.side_effect = [
|
| 186 |
+
(mock_model1, {"algorithm": "isolation_forest", "metrics": {}}),
|
| 187 |
+
(mock_model2, {"algorithm": "random_forest", "metrics": {}})
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
results = await pipeline.compare_models(
|
| 191 |
+
[("model1", 1), ("model2", 2)],
|
| 192 |
+
test_data,
|
| 193 |
+
test_labels
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
assert "model1_v1" in results
|
| 197 |
+
assert "model2_v2" in results
|
| 198 |
+
assert "test_metrics" in results["model1_v1"]
|
| 199 |
+
assert "anomaly_scores" in results["model1_v1"]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TestABTestingFramework:
|
| 203 |
+
"""Test suite for A/B testing framework."""
|
| 204 |
+
|
| 205 |
+
@pytest.fixture
|
| 206 |
+
def ab_framework(self):
|
| 207 |
+
"""Create a test A/B testing framework."""
|
| 208 |
+
return ABTestFramework()
|
| 209 |
+
|
| 210 |
+
@pytest.mark.asyncio
|
| 211 |
+
async def test_create_ab_test(self, ab_framework):
|
| 212 |
+
"""Test creating an A/B test."""
|
| 213 |
+
with patch.object(training_pipeline, 'load_model') as mock_load, \
|
| 214 |
+
patch('src.infrastructure.cache.redis_client.get_redis_client') as mock_redis:
|
| 215 |
+
|
| 216 |
+
# Mock model loading
|
| 217 |
+
mock_load.return_value = (MagicMock(), {})
|
| 218 |
+
|
| 219 |
+
# Mock Redis
|
| 220 |
+
mock_redis_client = AsyncMock()
|
| 221 |
+
mock_redis_client.set = AsyncMock()
|
| 222 |
+
mock_redis.return_value = mock_redis_client
|
| 223 |
+
|
| 224 |
+
test_config = await ab_framework.create_test(
|
| 225 |
+
test_name="test_ab",
|
| 226 |
+
model_a=("model1", 1),
|
| 227 |
+
model_b=("model2", 1),
|
| 228 |
+
traffic_split=(0.6, 0.4),
|
| 229 |
+
success_metric="accuracy"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
assert test_config["test_name"] == "test_ab"
|
| 233 |
+
assert test_config["traffic_split"] == (0.6, 0.4)
|
| 234 |
+
assert test_config["status"] == ABTestStatus.DRAFT.value
|
| 235 |
+
assert test_config["model_a"]["model_id"] == "model1"
|
| 236 |
+
assert test_config["model_b"]["model_id"] == "model2"
|
| 237 |
+
|
| 238 |
+
@pytest.mark.asyncio
|
| 239 |
+
async def test_start_ab_test(self, ab_framework):
|
| 240 |
+
"""Test starting an A/B test."""
|
| 241 |
+
# Create test first
|
| 242 |
+
test_config = {
|
| 243 |
+
"test_name": "test_ab",
|
| 244 |
+
"status": ABTestStatus.DRAFT.value,
|
| 245 |
+
"start_time": None
|
| 246 |
+
}
|
| 247 |
+
ab_framework.active_tests["test_ab"] = test_config
|
| 248 |
+
|
| 249 |
+
with patch('src.infrastructure.cache.redis_client.get_redis_client') as mock_redis:
|
| 250 |
+
mock_redis_client = AsyncMock()
|
| 251 |
+
mock_redis_client.set = AsyncMock()
|
| 252 |
+
mock_redis.return_value = mock_redis_client
|
| 253 |
+
|
| 254 |
+
success = await ab_framework.start_test("test_ab")
|
| 255 |
+
|
| 256 |
+
assert success is True
|
| 257 |
+
assert test_config["status"] == ABTestStatus.RUNNING.value
|
| 258 |
+
assert test_config["start_time"] is not None
|
| 259 |
+
|
| 260 |
+
@pytest.mark.asyncio
|
| 261 |
+
async def test_allocate_model_random(self, ab_framework):
|
| 262 |
+
"""Test random model allocation."""
|
| 263 |
+
test_config = {
|
| 264 |
+
"test_name": "test_ab",
|
| 265 |
+
"status": ABTestStatus.RUNNING.value,
|
| 266 |
+
"allocation_strategy": TrafficAllocationStrategy.RANDOM.value,
|
| 267 |
+
"traffic_split": (0.5, 0.5),
|
| 268 |
+
"model_a": {"model_id": "model1", "version": 1},
|
| 269 |
+
"model_b": {"model_id": "model2", "version": 1}
|
| 270 |
+
}
|
| 271 |
+
ab_framework.active_tests["test_ab"] = test_config
|
| 272 |
+
|
| 273 |
+
# Test multiple allocations
|
| 274 |
+
allocations = []
|
| 275 |
+
for _ in range(100):
|
| 276 |
+
model_id, version = await ab_framework.allocate_model("test_ab")
|
| 277 |
+
allocations.append(model_id)
|
| 278 |
+
|
| 279 |
+
# Should have both models allocated
|
| 280 |
+
assert "model1" in allocations
|
| 281 |
+
assert "model2" in allocations
|
| 282 |
+
|
| 283 |
+
@pytest.mark.asyncio
|
| 284 |
+
async def test_record_prediction(self, ab_framework):
|
| 285 |
+
"""Test recording prediction results."""
|
| 286 |
+
test_config = {
|
| 287 |
+
"test_name": "test_ab",
|
| 288 |
+
"status": ABTestStatus.RUNNING.value,
|
| 289 |
+
"allocation_strategy": TrafficAllocationStrategy.RANDOM.value,
|
| 290 |
+
"results": {
|
| 291 |
+
"model_a": {"predictions": 0, "successes": 0},
|
| 292 |
+
"model_b": {"predictions": 0, "successes": 0}
|
| 293 |
+
},
|
| 294 |
+
"minimum_sample_size": 10
|
| 295 |
+
}
|
| 296 |
+
ab_framework.active_tests["test_ab"] = test_config
|
| 297 |
+
|
| 298 |
+
with patch('src.infrastructure.cache.redis_client.get_redis_client') as mock_redis:
|
| 299 |
+
mock_redis_client = AsyncMock()
|
| 300 |
+
mock_redis_client.set = AsyncMock()
|
| 301 |
+
mock_redis.return_value = mock_redis_client
|
| 302 |
+
|
| 303 |
+
# Record some predictions
|
| 304 |
+
await ab_framework.record_prediction("test_ab", "model_a", True)
|
| 305 |
+
await ab_framework.record_prediction("test_ab", "model_a", False)
|
| 306 |
+
await ab_framework.record_prediction("test_ab", "model_b", True)
|
| 307 |
+
|
| 308 |
+
assert test_config["results"]["model_a"]["predictions"] == 2
|
| 309 |
+
assert test_config["results"]["model_a"]["successes"] == 1
|
| 310 |
+
assert test_config["results"]["model_b"]["predictions"] == 1
|
| 311 |
+
assert test_config["results"]["model_b"]["successes"] == 1
|
| 312 |
+
|
| 313 |
+
@pytest.mark.asyncio
|
| 314 |
+
async def test_analyze_test(self, ab_framework):
|
| 315 |
+
"""Test analyzing A/B test results."""
|
| 316 |
+
test_config = {
|
| 317 |
+
"test_name": "test_ab",
|
| 318 |
+
"results": {
|
| 319 |
+
"model_a": {"predictions": 1000, "successes": 520},
|
| 320 |
+
"model_b": {"predictions": 1000, "successes": 480}
|
| 321 |
+
},
|
| 322 |
+
"significance_level": 0.05
|
| 323 |
+
}
|
| 324 |
+
ab_framework.active_tests["test_ab"] = test_config
|
| 325 |
+
|
| 326 |
+
with patch('src.infrastructure.cache.redis_client.get_redis_client') as mock_redis:
|
| 327 |
+
mock_redis_client = AsyncMock()
|
| 328 |
+
mock_redis_client.set = AsyncMock()
|
| 329 |
+
mock_redis.return_value = mock_redis_client
|
| 330 |
+
|
| 331 |
+
analysis = await ab_framework.analyze_test("test_ab")
|
| 332 |
+
|
| 333 |
+
assert "model_a" in analysis
|
| 334 |
+
assert "model_b" in analysis
|
| 335 |
+
assert "p_value" in analysis
|
| 336 |
+
assert "significant" in analysis
|
| 337 |
+
assert "lift" in analysis
|
| 338 |
+
assert analysis["model_a"]["conversion_rate"] == 0.52
|
| 339 |
+
assert analysis["model_b"]["conversion_rate"] == 0.48
|
| 340 |
+
|
| 341 |
+
@pytest.mark.asyncio
|
| 342 |
+
async def test_thompson_sampling_allocation(self, ab_framework):
|
| 343 |
+
"""Test Thompson sampling allocation."""
|
| 344 |
+
test_config = {
|
| 345 |
+
"test_name": "test_ab",
|
| 346 |
+
"status": ABTestStatus.RUNNING.value,
|
| 347 |
+
"allocation_strategy": TrafficAllocationStrategy.THOMPSON_SAMPLING.value,
|
| 348 |
+
"thompson_params": {
|
| 349 |
+
"model_a": {"alpha": 10, "beta": 5},
|
| 350 |
+
"model_b": {"alpha": 5, "beta": 10}
|
| 351 |
+
},
|
| 352 |
+
"model_a": {"model_id": "model1", "version": 1},
|
| 353 |
+
"model_b": {"model_id": "model2", "version": 1}
|
| 354 |
+
}
|
| 355 |
+
ab_framework.active_tests["test_ab"] = test_config
|
| 356 |
+
|
| 357 |
+
# Test allocation - should favor model_a due to higher success rate
|
| 358 |
+
allocations = []
|
| 359 |
+
for _ in range(100):
|
| 360 |
+
model_id, _ = await ab_framework.allocate_model("test_ab")
|
| 361 |
+
allocations.append(model_id)
|
| 362 |
+
|
| 363 |
+
# Model 1 should be allocated more often
|
| 364 |
+
model1_count = allocations.count("model1")
|
| 365 |
+
model2_count = allocations.count("model2")
|
| 366 |
+
|
| 367 |
+
# Thompson sampling is probabilistic, but model1 should generally be favored
|
| 368 |
+
assert model1_count > 0
|
| 369 |
+
assert model2_count > 0
|
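The Thompson sampling test above exercises Beta-distributed arm selection. A minimal sketch of that allocation rule, assuming the same alpha/beta bookkeeping as the test fixture (the framework's own implementation in src/ml/ab_testing.py is the source of truth):

import numpy as np

def thompson_allocate(params: dict) -> str:
    """Pick the arm whose sampled success probability is highest.

    params maps arm name -> {"alpha": successes + 1, "beta": failures + 1}.
    """
    samples = {
        arm: np.random.beta(p["alpha"], p["beta"])
        for arm, p in params.items()
    }
    return max(samples, key=samples.get)

# With alpha=10/beta=5 versus alpha=5/beta=10, "model_a" wins most draws,
# which is the behaviour test_thompson_sampling_allocation checks statistically.
counts = {"model_a": 0, "model_b": 0}
for _ in range(1000):
    counts[thompson_allocate({
        "model_a": {"alpha": 10, "beta": 5},
        "model_b": {"alpha": 5, "beta": 10},
    })] += 1
print(counts)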