
ML Model Monitoring & Retraining Pipeline (Evidently + MLflow)


Evidently + MLflow + Grafana | Production MLOps

Advanced | 12+ Hours | Production-Ready

Project Overview

Problem Statement

ML models degrade over time due to data drift, concept drift, and changing business conditions. Without monitoring, degraded models silently make poor predictions, causing revenue loss and customer dissatisfaction. This system automatically detects drift, alerts stakeholders, and triggers retraining.

Objectives

  • Implement comprehensive data and model drift detection
  • Build automated retraining pipeline triggered by drift
  • Create dashboards for model health monitoring
  • Set up alerting for model performance degradation
  • Maintain full experiment lineage with MLflow

Component              | Technology
Drift Detection        | Evidently AI
Experiment Tracking    | MLflow
Monitoring Dashboard   | Grafana + Prometheus
Alerting               | PagerDuty + Slack
Pipeline Orchestration | Apache Airflow
Data Store             | PostgreSQL + S3
Model Registry         | MLflow Model Registry

Architecture Diagram

+-------------------------------------------------------------------+
|           Model Monitoring & Retraining Architecture               |
+-------------------------------------------------------------------+
|  +--------------+    +--------------+    +------------------+     |
|  | Production   |--->| Data/Model   |--->| Evidently AI     |     |
|  | Predictions  |    | Snapshot     |    | Drift Reports    |     |
|  +--------------+    +--------------+    +--------+---------+     |
|                                                  |               |
|                                                  v               |
|  +--------------+    +--------------+    +------------------+     |
|  |  Alerts      |<---|  Drift       |<---|  Threshold       |     |
|  |  (PagerDuty) |    |  Detector    |    |  Manager         |     |
|  +--------------+    +--------------+    +------------------+     |
|        |                                                   |     |
|        v                                                   v     |
|  +--------------+    +--------------+    +------------------+     |
|  |  Airflow     |--->|  Auto        |--->|  MLflow          |     |
|  |  Trigger     |    |  Retrain     |    |  Experiment Log  |     |
|  +--------------+    +--------------+    +------------------+     |
+-------------------------------------------------------------------+

Step-by-Step Implementation

Step 1: Environment Setup

mkdir model-monitoring && cd model-monitoring
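# The monitoring code below uses Evidently's Report/TestSuite API from the 0.4.x line;
# later releases reorganized it, so pin the version.
pip install "evidently<0.5" mlflow prometheus-client grafana-api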
pip install pandas numpy scikit-learn
pip install fastapi uvicorn psycopg2-binary
pip install apache-airflow great-expectations

Step 2: Evidently AI Drift Detection

Set up comprehensive monitoring for data quality, data drift, and model performance.

# src/monitoring/drift_detector.py
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import pandas as pd
from evidently import ColumnMapping
from evidently.metric_preset import (
    DataDriftPreset, DataQualityPreset, TargetDriftPreset,
)
from evidently.report import Report
from evidently.test_suite import TestSuite
from evidently.tests import (
    TestColumnDrift, TestColumnShareOfMissingValues,
    TestShareOfDriftedColumns, TestValueRange,
)

logger = logging.getLogger(__name__)


@dataclass
class DriftConfig:
    drift_threshold: float = 0.05  # passed to Evidently as the stattest threshold
    missing_threshold: float = 0.1
    alert_on_drift: bool = True
    reference_period_days: int = 30


class DriftDetector:
    def __init__(self, config: Optional[DriftConfig] = None):
        self.config = config or DriftConfig()
        self.reference_data: Optional[pd.DataFrame] = None
        self.target_col: Optional[str] = None  # set by set_reference_data

    def set_reference_data(self, reference_df: pd.DataFrame, target_col: Optional[str] = None):
        self.reference_data = reference_df
        self.target_col = target_col

    def create_column_mapping(self, df: pd.DataFrame) -> ColumnMapping:
        categorical_cols = df.select_dtypes(include=["object", "category"]).columns.tolist()
        numerical_cols = df.select_dtypes(include=[np.number]).columns.tolist()

        # Keep the target out of both feature lists, whatever its dtype
        mapping = ColumnMapping(
            target=self.target_col if self.target_col in df.columns else None,
            numerical_features=[c for c in numerical_cols if c != self.target_col],
            categorical_features=[c for c in categorical_cols if c != self.target_col],
        )
        return mapping

    def run_drift_analysis(
        self, current_df: pd.DataFrame,
        report_name: str = "drift_report"
    ) -> Dict:
        if self.reference_data is None:
            raise ValueError("Reference data not set. Call set_reference_data first.")

        column_mapping = self.create_column_mapping(current_df)

        # KS only applies to numerical columns; categorical columns keep
        # Evidently's default test.
        metrics = [
            DataDriftPreset(
                num_stattest="ks",
                num_stattest_threshold=self.config.drift_threshold,
            ),
            DataQualityPreset(),
        ]
        # A None entry in the metrics list would break the report, so only
        # append the target drift preset when a target column is mapped.
        if column_mapping.target is not None:
            metrics.append(TargetDriftPreset())
        drift_report = Report(metrics=metrics)

        drift_report.run(
            reference_data=self.reference_data,
            current_data=current_df,
            column_mapping=column_mapping,
        )

        # Save report (create the output directory on first use)
        report_path = Path("reports") / f"{report_name}.html"
        report_path.parent.mkdir(parents=True, exist_ok=True)
        drift_report.save_html(str(report_path))

        # Extract results
        results = self._extract_results(drift_report)
        logger.info(f"Drift analysis complete: {results}")
        return results

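    def run_data_tests(self, current_df: pd.DataFrame) -> Dict:
        if self.reference_data is None:
            raise ValueError("Reference data not set. Call set_reference_data first.")

        # "feature_1" is a placeholder; replace it with your own column names.
        test_suite = TestSuite(tests=[
            TestShareOfDriftedColumns(lt=0.3),
            TestColumnDrift(column_name="feature_1", stattest="ks"),
            TestColumnShareOfMissingValues(
                column_name="feature_1", lt=self.config.missing_threshold
            ),
            TestValueRange(column_name="feature_1", left=0, right=1000),
        ])

        test_suite.run(
            reference_data=self.reference_data,
            current_data=current_df,
        )

        # Serialize once and reuse the result
        suite_dict = test_suite.as_dict()
        return {
            "passed": suite_dict["summary"]["all_passed"],
            "tests": suite_dict["tests"],
        }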

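    def _extract_results(self, report) -> Dict:
        report_dict = report.as_dict()
        results = {
            "drift_detected": False,
            "drifted_columns": [],
            "drift_scores": {},
        }

        for metric in report_dict.get("metrics", []):
            name = metric.get("metric")
            result = metric.get("result", {})
            if name == "DatasetDriftMetric":
                # The dataset-level verdict is reported under "dataset_drift"
                results["drift_detected"] = bool(result.get("dataset_drift", False))
                continue
            # DataDriftTable nests per-column results; ColumnDriftMetric holds one column
            if name == "DataDriftTable":
                columns = result.get("drift_by_columns", {}).values()
            elif name == "ColumnDriftMetric":
                columns = [result]
            else:
                continue
            for info in columns:
                col = info.get("column_name", "")
                results["drift_scores"][col] = info.get("drift_score", 0)
                # Record a column as drifted only when its own test flagged it
                if info.get("drift_detected", False) and col not in results["drifted_columns"]:
                    results["drifted_columns"].append(col)

        return results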
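
To make the "ks" stattest used above more concrete, here is a minimal pure-Python sketch of the two-sample Kolmogorov-Smirnov statistic. The function name `ks_statistic` is hypothetical and this is not Evidently's implementation; as far as its documentation describes, the library works with the test's p-value rather than the raw statistic computed here.

```python
from bisect import bisect_right
from typing import Sequence


def ks_statistic(reference: Sequence[float], current: Sequence[float]) -> float:
    """Two-sample KS statistic: the largest gap between the empirical CDFs
    of the two samples (0.0 = identical, 1.0 = fully separated)."""
    if not reference or not current:
        raise ValueError("Both samples must be non-empty.")
    ref_sorted = sorted(reference)
    cur_sorted = sorted(current)
    max_gap = 0.0
    # The gap can only change at observed values, so checking those is enough.
    for x in ref_sorted + cur_sorted:
        ref_cdf = bisect_right(ref_sorted, x) / len(ref_sorted)
        cur_cdf = bisect_right(cur_sorted, x) / len(cur_sorted)
        max_gap = max(max_gap, abs(ref_cdf - cur_cdf))
    return max_gap
```

A larger statistic means the current feature distribution has moved further from the reference; turning it into a drift verdict still requires a significance test or a tuned cutoff.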

Step 3: Prometheus Metrics Export

Export model metrics to Prometheus for real-time monitoring and alerting.

# src/monitoring/metrics_exporter.py
from prometheus_client import Gauge, Counter, Histogram, start_http_server
import time
from typing import Dict


class ModelMetricsExporter:
    def __init__(self, port: int = 8001):
        self.port = port

        # Define metrics
        self.prediction_latency = Histogram(
            "model_prediction_latency_seconds",
            "Model prediction latency in seconds",
            buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
        )
        self.prediction_count = Counter(
            "model_predictions_total",
            "Total number of predictions",
            ["model_version", "status"],
        )
        self.data_drift_score = Gauge(
            "model_data_drift_score",
            "Current data drift score",
            ["feature_name"],
        )
        self.model_accuracy = Gauge(
            "model_accuracy",
            "Current model accuracy",
            ["model_version"],
        )
        self.feature_missing_ratio = Gauge(
            "feature_missing_ratio",
            "Missing value ratio per feature",
            ["feature_name"],
        )

    def start(self):
        start_http_server(self.port)

    def record_prediction(self, latency: float, model_version: str, success: bool):
        self.prediction_latency.observe(latency)
        status = "success" if success else "error"
        self.prediction_count.labels(model_version=model_version, status=status).inc()

    def record_drift_scores(self, drift_scores: Dict[str, float]):
        for feature, score in drift_scores.items():
            self.data_drift_score.labels(feature_name=feature).set(score)

    def record_accuracy(self, accuracy: float, model_version: str):
        self.model_accuracy.labels(model_version=model_version).set(accuracy)
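
One way to feed `record_prediction` from serving code is a small timing wrapper. The sketch below is an assumption about how you might wire it up: `timed_prediction` is a hypothetical helper, and it takes any callback with the `(latency, model_version, success)` signature so it can be exercised without Prometheus running.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def timed_prediction(
    record: Callable[[float, str, bool], None], model_version: str
) -> Iterator[None]:
    """Time the wrapped block and report latency, version and success flag."""
    start = time.perf_counter()
    success = True
    try:
        yield
    except Exception:
        success = False
        raise  # never swallow the caller's error
    finally:
        record(time.perf_counter() - start, model_version, success)


# Intended usage with the exporter above (model and features are placeholders):
#   exporter = ModelMetricsExporter()
#   exporter.start()
#   with timed_prediction(exporter.record_prediction, "v3"):
#       model.predict(features)
```

Failed predictions are still timed and counted with `status="error"`, so the error rate can be derived from `model_predictions_total`.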

Step 4: Automated Retraining Trigger

# src/retraining/trigger.py
import logging
from typing import Dict, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum

logger = logging.getLogger(__name__)


class RetrainReason(Enum):
    DATA_DRIFT = "data_drift"
    PERFORMANCE_DEGRADATION = "performance_degradation"
    SCHEDULED = "scheduled"
    MANUAL = "manual"


@dataclass
class RetrainTriggerConfig:
    drift_threshold: float = 0.1
    accuracy_drop_threshold: float = 0.05
    min_samples_for_eval: int = 1000
    cooldown_hours: int = 24
    max_consecutive_failures: int = 3


class RetrainTrigger:
    def __init__(self, config: Optional[RetrainTriggerConfig] = None):
        self.config = config or RetrainTriggerConfig()
        self.last_retrain_time: Optional[datetime] = None
        self.consecutive_failures = 0

    def should_retrain(self, monitoring_results: Dict) -> Optional[RetrainReason]:
        # Check cooldown
        if self.last_retrain_time:
            elapsed = datetime.now() - self.last_retrain_time
            if elapsed < timedelta(hours=self.config.cooldown_hours):
                return None

        # Check data drift
        if monitoring_results.get("drift_detected", False):
            logger.warning("Data drift detected - triggering retrain")
            return RetrainReason.DATA_DRIFT

        # Check performance degradation
        current_accuracy = monitoring_results.get("current_accuracy", 1.0)
        baseline_accuracy = monitoring_results.get("baseline_accuracy", 1.0)
        if baseline_accuracy - current_accuracy > self.config.accuracy_drop_threshold:
            logger.warning("Performance degradation detected - triggering retrain")
            return RetrainReason.PERFORMANCE_DEGRADATION

        # Check scheduled retrain
        if monitoring_results.get("days_since_last_train", 0) > 30:
            return RetrainReason.SCHEDULED

        return None

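    def execute_retrain(self, reason: RetrainReason) -> bool:
        logger.info(f"Starting retraining due to: {reason.value}")

        try:
            # Trigger Airflow DAG (imported lazily so this module loads without Airflow)
            from airflow.api.client.local_client import Client
            client = Client(None, None)
            client.trigger_dag(
                dag_id="model_retraining",
                conf={"reason": reason.value, "timestamp": datetime.now().isoformat()},
            )
        except Exception as e:
            self.consecutive_failures += 1
            logger.error(f"Retrain trigger failed: {e}")
            if self.consecutive_failures >= self.config.max_consecutive_failures:
                logger.critical(
                    f"Retrain trigger failed {self.consecutive_failures} times in a row"
                )
            return False

        # Start the cooldown only after the DAG was triggered successfully,
        # so a failed trigger can be retried on the next check.
        self.last_retrain_time = datetime.now()
        self.consecutive_failures = 0
        return True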
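
`RetrainTriggerConfig` defines `min_samples_for_eval`, but the trigger above never reads it. The following sketch shows one way the accuracy check could respect it; `accuracy_drop_exceeded` is a hypothetical helper, and it assumes a classification task with exact-match labels.

```python
from typing import Optional, Sequence


def accuracy_drop_exceeded(
    y_true: Sequence[int],
    y_pred: Sequence[int],
    baseline_accuracy: float,
    drop_threshold: float = 0.05,
    min_samples: int = 1000,
) -> Optional[bool]:
    """Return True/False for a degradation verdict, or None when there are
    too few labelled samples to judge."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length.")
    if len(y_true) < min_samples:
        return None
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    current_accuracy = correct / len(y_true)
    return baseline_accuracy - current_accuracy > drop_threshold
```

Returning `None` rather than `False` lets the caller distinguish "model is fine" from "not enough labels yet", which matters when ground truth arrives with a delay.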

Step 5: MLflow Experiment Tracking Integration

# src/tracking/mlflow_tracker.py
import mlflow
import mlflow.sklearn
from typing import Dict, Any, Optional
from datetime import datetime


class MLflowTracker:
    def __init__(self, experiment_name: str = "model_monitoring"):
        mlflow.set_experiment(experiment_name)

    def log_drift_event(self, drift_results: Dict, retrain_triggered: bool):
        with mlflow.start_run(run_name=f"drift-check-{datetime.now().strftime('%Y%m%d-%H%M')}"):
            mlflow.log_params({
                "check_type": "drift_detection",
                "drift_detected": drift_results.get("drift_detected", False),
                "num_drifted_columns": len(drift_results.get("drifted_columns", [])),
                "retrain_triggered": retrain_triggered,
            })

            for col, score in drift_results.get("drift_scores", {}).items():
                mlflow.log_metric(f"drift_score_{col}", score)

            if retrain_triggered:
                mlflow.set_tag("retrain_reason", "data_drift")

    def log_model_performance(
        self, model_version: str, metrics: Dict[str, float],
        dataset_info: Dict[str, Any]
    ):
        with mlflow.start_run(run_name=f"perf-{model_version}"):
            mlflow.log_params({
                "model_version": model_version,
                "dataset_size": dataset_info.get("size", 0),
                "evaluation_date": datetime.now().isoformat(),
            })
            mlflow.log_metrics(metrics)

    def log_retrain_event(
        self, reason: str, success: bool,
        metrics_before: Dict, metrics_after: Dict
    ):
        with mlflow.start_run(run_name=f"retrain-{datetime.now().strftime('%Y%m%d')}"):
            mlflow.log_params({
                "reason": reason,
                "success": success,
            })
            for k, v in metrics_before.items():
                mlflow.log_metric(f"before_{k}", v)
            for k, v in metrics_after.items():
                mlflow.log_metric(f"after_{k}", v)
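
The tracker builds metric keys such as `drift_score_{col}` directly from column names. MLflow restricts which characters a metric name may contain, so a column like `price ($)` can be rejected. The sketch below uses a conservative allow-list (alphanumerics, underscore, dash, period, space, slash), which is an assumption you should check against the MLflow version you run; `safe_metric_name` is a hypothetical helper.

```python
import re

# Anything outside this conservative allow-list is replaced.
_INVALID_CHARS = re.compile(r"[^A-Za-z0-9_\-. /]")


def safe_metric_name(prefix: str, column: str) -> str:
    """Build a metric key, replacing disallowed characters with underscores."""
    return f"{prefix}_{_INVALID_CHARS.sub('_', column)}"
```

Inside `log_drift_event`, the call would then read `mlflow.log_metric(safe_metric_name("drift_score", col), score)`.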

ℹ️

Set up baseline metrics when deploying a new model. All subsequent monitoring should compare against this baseline. Store baseline metrics in a separate, immutable location.
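
As a rough illustration of a write-once baseline, the sketch below stores one JSON file per model version and refuses to overwrite it. The function names and file layout are hypothetical; on S3 the equivalent would be something like object versioning or a write-once policy, which is not shown here.

```python
import json
from pathlib import Path
from typing import Dict


def write_baseline(model_version: str, metrics: Dict[str, float], directory: Path) -> Path:
    """Write the baseline for a model version once; fail if it already exists."""
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / f"baseline_{model_version}.json"
    # Mode "x" raises FileExistsError instead of silently overwriting.
    with path.open("x", encoding="utf-8") as fh:
        json.dump({"model_version": model_version, "metrics": metrics}, fh, sort_keys=True)
    return path


def load_baseline(path: Path) -> Dict[str, float]:
    """Read back the stored baseline metrics."""
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)["metrics"]
```

The loaded value can be passed to the retrain trigger as `baseline_accuracy`, so every comparison uses the numbers captured at deployment time.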

💡

Use sliding windows for drift detection to avoid false positives from natural data variation. A 7-day window with a 3-day step provides good stability.
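
The 7-day window with a 3-day step can be generated with a few lines of standard-library code. This is a sketch under the assumption that windows are half-open date ranges and that partial windows at the end are skipped; `sliding_windows` is a hypothetical helper.

```python
from datetime import date, timedelta
from typing import Iterator, Tuple


def sliding_windows(
    start: date, end: date, window_days: int = 7, step_days: int = 3
) -> Iterator[Tuple[date, date]]:
    """Yield (window_start, window_end) pairs, end exclusive, that fit
    fully inside [start, end)."""
    if window_days <= 0 or step_days <= 0:
        raise ValueError("window_days and step_days must be positive.")
    window_start = start
    while window_start + timedelta(days=window_days) <= end:
        yield window_start, window_start + timedelta(days=window_days)
        window_start += timedelta(days=step_days)
```

Each window would be filtered out of the prediction log and passed to the drift check as the current data. Because consecutive windows overlap by four days, a single unusual day affects several checks instead of dominating one.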

Performance Metrics

Metric                  | Target  | Description
Drift Detection Latency | < 5 min | Time to detect drift
Retrain Pipeline        | < 2 hr  | End-to-end retraining
False Alarm Rate        | < 5%    | Drift false positives
Model Refresh Rate      | Weekly  | Scheduled retraining
Monitoring Coverage     | 100%    | All production models

Interview Talking Points

  1. Drift Types: Data drift (input distribution change) vs concept drift (relationship between input and target changes). Different detection methods for each.
  2. Evidently AI: Provides comprehensive reports with statistical tests (KS test, PSI, Jensen-Shannon) for drift detection.
  3. Retraining Strategy: Automatic retraining with human approval gates prevents deploying broken models.
  4. Prometheus + Grafana: Industry-standard monitoring stack that integrates with existing infrastructure.
  5. MLflow Integration: Full lineage from monitoring alerts to retraining experiments to deployed models.
  6. Gradual Rollout: Canary deployments and A/B testing for new model versions reduce deployment risk.
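
Talking point 2 names PSI among the statistical tests. The sketch below shows what the Population Stability Index computes, in pure Python. It assumes equal-width bins over the reference range and a small `eps` floor for empty bins; both are illustrative choices, and Evidently's own binning may differ.

```python
import math
from typing import List, Sequence


def population_stability_index(
    reference: Sequence[float],
    current: Sequence[float],
    bins: int = 10,
    eps: float = 1e-6,
) -> float:
    """PSI = sum over bins of (cur_share - ref_share) * ln(cur_share / ref_share)."""
    if not reference or not current:
        raise ValueError("Both samples must be non-empty.")
    lo, hi = min(reference), max(reference)
    width = (hi - lo) / bins or 1.0  # guard against a constant reference column

    def shares(values: Sequence[float]) -> List[float]:
        counts = [0] * bins
        for v in values:
            # Clamp out-of-range values into the first or last bin.
            idx = min(max(int((v - lo) / width), 0), bins - 1)
            counts[idx] += 1
        # Floor empty bins at eps so the logarithm stays defined.
        return [max(c / len(values), eps) for c in counts]

    ref_shares, cur_shares = shares(reference), shares(current)
    return sum((c - r) * math.log(c / r) for r, c in zip(ref_shares, cur_shares))
```

Every term in the sum is non-negative, so PSI is zero only when the binned distributions match and grows as the current data shifts away from the reference.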

⚠️

Aggressive drift thresholds lead to frequent, unnecessary retraining. Start with conservative thresholds and tighten based on observed false positive rates.
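
One hedged way to pick a starting threshold is to calibrate it on windows you believe were stable, so that only a chosen share of them would have raised an alarm. The sketch assumes a score where higher means more drift (such as PSI or the raw KS statistic, not a p-value); `calibrate_threshold` is a hypothetical helper.

```python
import math
from typing import Sequence


def calibrate_threshold(
    stable_scores: Sequence[float], max_false_alarm_rate: float = 0.05
) -> float:
    """Return the smallest observed score such that the share of stable
    windows scoring strictly above it stays within max_false_alarm_rate."""
    if not stable_scores:
        raise ValueError("Need at least one score from a stable period.")
    if not 0 <= max_false_alarm_rate < 1:
        raise ValueError("max_false_alarm_rate must be in [0, 1).")
    ordered = sorted(stable_scores)
    # Number of stable windows that are allowed to raise an alarm.
    allowed = math.floor(max_false_alarm_rate * len(ordered))
    return ordered[len(ordered) - 1 - allowed]
```

Alerting when a new window's score exceeds the returned value keeps the false alarm rate on the calibration data at or below the target; the threshold can then be revisited as more labelled history accumulates.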

ℹ️

For online learning scenarios, consider the river library (the successor to creme, which was merged into it) for continuous model updates instead of batch retraining.
