🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans (from $10)

Model Monitoring & Drift Detection

MLOps · Monitoring · ⭐ Premium

Advertisement

Model Monitoring & Drift Detection

Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe

Why Model Monitoring?

Models degrade over time due to data drift, concept drift, and changing business conditions.

ℹ️

Amazon's monitoring systems detect model degradation within 15 minutes, triggering automated rollback.

Drift Detection

# drift_detection.py
import numpy as np
from scipy import stats
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
import warnings

# NOTE(review): this installs a catch-all "ignore" filter at import time, so
# every warning in the process is silenced from here on (including NumPy/SciPy
# RuntimeWarnings such as divide-by-zero). A scoped `warnings.catch_warnings()`
# around the specific noisy call would be safer.
warnings.simplefilter("ignore")

class DriftType(Enum):
    """Kind of drift a ``DriftAlert`` refers to; values are the serialized labels."""

    DATA_DRIFT = "data_drift"              # shift in the input data
    CONCEPT_DRIFT = "concept_drift"        # shift in the input -> label relationship
    PREDICTION_DRIFT = "prediction_drift"  # shift in the model's outputs
    FEATURE_DRIFT = "feature_drift"        # shift in an individual feature

class DriftSeverity(Enum):
    """Graded severity attached to a drift alert, from NONE up to HIGH."""

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

@dataclass
class DriftAlert:
    """Record describing one drift finding emitted by a monitoring check."""

    drift_type: DriftType          # which kind of drift was detected
    feature_name: Optional[str]    # name the check was run under; may be None
    severity: DriftSeverity        # graded severity of the finding
    score: float                   # value produced by the drift test
    threshold: float               # threshold value reported alongside the score
    details: str                   # human-readable description of the result
    timestamp: str                 # timestamp as text (ISO-8601 where produced via isoformat())

class DriftDetector:
    """Statistical drift tests of a current batch against a fixed reference sample.

    ``reference_data`` is laid out as ``(n_samples, n_features)``; a 1-D array
    is accepted and treated as a single feature column. Every ``detect_*``
    method expects ``current_data`` in the same layout and returns a tuple
    whose first element is the drift decision.
    """

    def __init__(self, reference_data: np.ndarray, significance_level: float = 0.05):
        self.reference_data = self._as_2d(reference_data)
        # p-value below which detect_ks_drift reports drift.
        self.significance_level = significance_level
        self.reference_stats = self._compute_stats(self.reference_data)

    @staticmethod
    def _as_2d(data: np.ndarray) -> np.ndarray:
        """Return ``data`` as an array, promoting a 1-D input to a single column."""
        arr = np.asarray(data)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    def _compute_stats(self, data: np.ndarray) -> Dict:
        """Per-column summary statistics (mean, std, median, quartiles) of ``data``."""
        return {
            "mean": np.mean(data, axis=0),
            "std": np.std(data, axis=0),
            "median": np.median(data, axis=0),
            "q25": np.percentile(data, 25, axis=0),
            "q75": np.percentile(data, 75, axis=0),
        }

    def detect_ks_drift(self, current_data: np.ndarray) -> Tuple[bool, float, float]:
        """Run a two-sample Kolmogorov-Smirnov test on every feature.

        Returns:
            ``(drift_detected, min_p_value, statistic)``. ``drift_detected`` is
            True when any feature's p-value is below ``significance_level``;
            ``statistic`` is the KS statistic of the feature with the smallest
            p-value (0.0 when there are no features).

        NOTE(review): no multiple-testing correction is applied, so the
        chance of a false alarm grows with the number of features.
        """
        current = self._as_2d(current_data)
        drift_detected = False
        min_p_value = 1.0
        worst_statistic = 0.0

        for i in range(self.reference_data.shape[1]):
            statistic, p_value = stats.ks_2samp(self.reference_data[:, i], current[:, i])
            if p_value < self.significance_level:
                drift_detected = True
            # Report the statistic of the most significant feature rather than
            # whichever feature happened to be tested last; on equal p-values
            # keep the larger statistic.
            if p_value < min_p_value or (p_value == min_p_value and statistic > worst_statistic):
                min_p_value = p_value
                worst_statistic = statistic

        return drift_detected, min_p_value, worst_statistic

    @staticmethod
    def _bin_fractions(values: np.ndarray, inner_edges: np.ndarray, n_bins: int) -> np.ndarray:
        """Fraction of ``values`` in each of ``n_bins`` bins split at ``inner_edges``.

        The first and last bins are open-ended, so values outside the
        reference range are counted in the outermost bins instead of dropped.
        """
        indices = np.searchsorted(inner_edges, values, side="right")
        return np.bincount(indices, minlength=n_bins) / len(values)

    def detect_psi(
        self, current_data: np.ndarray, n_bins: int = 10, threshold: float = 0.2
    ) -> Tuple[bool, float]:
        """Population Stability Index, averaged over features.

        Bins are ``n_bins`` equal-width intervals over each reference feature's
        range. PSI is computed from per-bin *proportions* (not densities), so
        it does not depend on the feature's units, and current values outside
        the reference range fall into the outermost bins.

        Returns:
            ``(avg_psi > threshold, avg_psi)``.

        Raises:
            ValueError: if ``current_data`` has no rows.
        """
        current = self._as_2d(current_data)
        if current.shape[0] == 0:
            raise ValueError("current_data must contain at least one row")

        eps = 1e-6  # floor for empty bins so the log ratio stays finite
        psi_values = []
        for i in range(self.reference_data.shape[1]):
            ref_col = self.reference_data[:, i]
            inner_edges = np.histogram_bin_edges(ref_col, bins=n_bins)[1:-1]
            ref_frac = np.clip(self._bin_fractions(ref_col, inner_edges, n_bins), eps, None)
            curr_frac = np.clip(self._bin_fractions(current[:, i], inner_edges, n_bins), eps, None)
            psi_values.append(np.sum((curr_frac - ref_frac) * np.log(curr_frac / ref_frac)))

        avg_psi = np.mean(psi_values)
        return avg_psi > threshold, avg_psi

    @staticmethod
    def _rbf_kernel(x: np.ndarray, y: np.ndarray, gamma: float) -> np.ndarray:
        """Gaussian kernel matrix ``K[i, j] = exp(-gamma * ||x_i - y_j||^2)``.

        Same formula as scikit-learn's ``rbf_kernel``, written with NumPy so
        the detector does not need scikit-learn at runtime.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        sq_dists = (
            np.sum(x * x, axis=1)[:, None]
            + np.sum(y * y, axis=1)[None, :]
            - 2.0 * (x @ y.T)
        )
        # Round-off can push tiny squared distances below zero; clamp them.
        return np.exp(-gamma * np.maximum(sq_dists, 0.0))

    def detect_mmd(
        self,
        current_data: np.ndarray,
        gamma: float = 1.0,
        threshold: float = 0.1,
        max_samples: int = 500,
    ) -> Tuple[bool, float]:
        """Maximum Mean Discrepancy with an RBF kernel on random subsamples.

        Up to ``max_samples`` rows are drawn *without replacement* from each
        set using NumPy's global random state (``np.random.seed`` makes the
        result reproducible). The estimate is the biased one: kernel means
        include the diagonal terms.

        Returns:
            ``(mmd > threshold, mmd)``.
        """
        current = self._as_2d(current_data)
        n_ref = min(max_samples, len(self.reference_data))
        n_curr = min(max_samples, len(current))

        ref_sample = self.reference_data[
            np.random.choice(len(self.reference_data), n_ref, replace=False)
        ]
        curr_sample = current[np.random.choice(len(current), n_curr, replace=False)]

        k_rr = self._rbf_kernel(ref_sample, ref_sample, gamma)
        k_cc = self._rbf_kernel(curr_sample, curr_sample, gamma)
        k_rc = self._rbf_kernel(ref_sample, curr_sample, gamma)

        mmd2 = np.mean(k_rr) + np.mean(k_cc) - 2 * np.mean(k_rc)
        mmd = np.sqrt(max(0, mmd2))
        return mmd > threshold, mmd

    def detect_concept_drift(
        self,
        reference_predictions: np.ndarray,
        reference_labels: np.ndarray,
        current_predictions: np.ndarray,
        current_labels: np.ndarray,
        decision_threshold: float = 0.5,
        ratio_threshold: float = 1.5,
    ) -> Tuple[bool, float]:
        """Flag drift when the current error rate rises relative to the reference.

        Predictions are binarized with ``> decision_threshold`` and compared
        with the labels (presumably 0/1 binary labels — TODO confirm).

        Returns:
            ``(error_ratio > ratio_threshold, error_ratio)`` where
            ``error_ratio = current_error_rate / max(reference_error_rate, 1e-6)``.
        """
        ref_errors = (np.asarray(reference_predictions) > decision_threshold).astype(int) != np.asarray(reference_labels)
        curr_errors = (np.asarray(current_predictions) > decision_threshold).astype(int) != np.asarray(current_labels)

        ref_error_rate = np.mean(ref_errors)
        curr_error_rate = np.mean(curr_errors)

        # Floor the denominator so a perfect reference model does not divide by zero.
        error_ratio = curr_error_rate / max(ref_error_rate, 1e-6)
        return error_ratio > ratio_threshold, error_ratio


class ModelMonitor:
    """Run registered drift detectors and keep alerts and a prediction log.

    Detectors are looked up by the name they were registered under. Alerts
    from every ``check_drift`` call accumulate in ``self.alerts``;
    ``record_prediction`` appends to ``self.metrics_history``, which
    ``get_prediction_drift`` compares window-over-window.
    """

    def __init__(self):
        self.drift_detectors: Dict[str, DriftDetector] = {}
        self.alerts: List[DriftAlert] = []
        # NOTE(review): unbounded in-memory log; a long-running service may
        # want a capped buffer here.
        self.metrics_history: List[Dict] = []

    def register_detector(self, name: str, detector: DriftDetector):
        """Register ``detector`` under ``name``, replacing any previous one."""
        self.drift_detectors[name] = detector

    def check_drift(self, feature_name: str, current_data: np.ndarray) -> List[DriftAlert]:
        """Run the KS and PSI tests of the detector registered as ``feature_name``.

        Returns the alerts raised by this call (they are also appended to
        ``self.alerts``). An unregistered ``feature_name`` yields ``[]``.
        """
        detector = self.drift_detectors.get(feature_name)
        if detector is None:
            return []

        alerts: List[DriftAlert] = []

        ks_drift, ks_p_value, ks_stat = detector.detect_ks_drift(current_data)
        if ks_drift:
            # NOTE(review): the KS decision compares the p-value with the
            # detector's significance_level; the reported score is the KS
            # statistic and the reported threshold a fixed 0.1, which are
            # not what the decision was based on.
            severity = DriftSeverity.HIGH if ks_p_value < 0.01 else DriftSeverity.MEDIUM
            alerts.append(self._make_alert(
                feature_name, severity, ks_stat, 0.1,
                f"KS test p-value: {ks_p_value:.4f}",
            ))

        psi_drift, psi_score = detector.detect_psi(current_data)
        if psi_drift:
            severity = DriftSeverity.HIGH if psi_score > 0.5 else DriftSeverity.MEDIUM
            alerts.append(self._make_alert(
                feature_name, severity, psi_score, 0.2,
                f"PSI score: {psi_score:.4f}",
            ))

        self.alerts.extend(alerts)
        return alerts

    @staticmethod
    def _make_alert(
        feature_name: str,
        severity: DriftSeverity,
        score: float,
        threshold: float,
        details: str,
    ) -> DriftAlert:
        """Build a DATA_DRIFT alert stamped with the current local time."""
        return DriftAlert(
            drift_type=DriftType.DATA_DRIFT,
            feature_name=feature_name,
            severity=severity,
            score=score,
            threshold=threshold,
            details=details,
            timestamp=datetime.now().isoformat(),
        )

    def record_prediction(self, prediction: float, actual: Optional[float] = None):
        """Append one prediction (and optional ground truth) to the history."""
        self.metrics_history.append({
            "prediction": prediction,
            "actual": actual,
            "timestamp": datetime.now().isoformat(),
        })

    def get_prediction_drift(self, window_size: int = 100, significance_level: float = 0.05) -> Dict:
        """KS-test the latest ``window_size`` predictions against the window before.

        Returns ``{"drift_detected": False}`` until ``2 * window_size``
        predictions have been recorded; otherwise the decision
        (``p_value < significance_level``) plus the KS statistic, p-value and
        both window means.

        Raises:
            ValueError: if ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if len(self.metrics_history) < window_size * 2:
            return {"drift_detected": False}

        recent = [m["prediction"] for m in self.metrics_history[-window_size:]]
        previous = [m["prediction"] for m in self.metrics_history[-window_size * 2:-window_size]]

        stat, p_value = stats.ks_2samp(previous, recent)
        return {
            "drift_detected": p_value < significance_level,
            "ks_statistic": stat,
            "p_value": p_value,
            "recent_mean": np.mean(recent),
            "previous_mean": np.mean(previous),
        }

    def get_alerts_summary(self) -> Dict:
        """Count all recorded alerts, in total and by severity and drift type."""
        from collections import Counter

        return {
            "total": len(self.alerts),
            "by_severity": dict(Counter(alert.severity.value for alert in self.alerts)),
            "by_type": dict(Counter(alert.drift_type.value for alert in self.alerts)),
        }


# Usage
# Synthetic example: the "current" batch is the reference distribution scaled
# by 1.2 and shifted by +0.5, so drift is expected.
reference_data = np.random.randn(1000, 5)
current_data = 0.5 + 1.2 * np.random.randn(1000, 5)

monitor = ModelMonitor()
detector = DriftDetector(reference_data)
# NOTE(review): registered as "feature_0", but the detector covers all 5 columns.
monitor.register_detector("feature_0", detector)

alerts = monitor.check_drift("feature_0", current_data)
for alert in alerts:
    print(
        f"Alert: {alert.drift_type.value} - "
        f"{alert.severity.value} - "
        f"{alert.details}"
    )

Monitoring Dashboard

# monitoring_metrics.py
import time
from typing import Dict, List
from dataclasses import dataclass
from datetime import datetime
import json

@dataclass
class PredictionMetric:
    """One served prediction together with its serving metadata."""

    prediction: float            # model output value
    latency_ms: float            # serving latency, in milliseconds per the field name
    timestamp: datetime          # time associated with the prediction (TODO confirm: naive vs. tz-aware)
    model_version: str           # version label; MetricsCollector counts per version
    features: Dict[str, float]   # input feature values keyed by name

class MetricsCollector:
    """In-memory store of prediction metrics with a summary and a Prometheus export.

    Every recorded ``PredictionMetric`` is kept in ``self.predictions``;
    ``counters`` tracks the overall total and a per-model-version count, and
    ``histograms["latency"]`` keeps the raw latency values.
    """

    def __init__(self):
        self.predictions: List[PredictionMetric] = []
        self.counters: Dict[str, int] = {}
        self.histograms: Dict[str, List[float]] = {}

    def record_prediction(self, metric: PredictionMetric):
        """Store ``metric`` and update the counters and the latency samples."""
        self.predictions.append(metric)
        self._update_counters(metric)
        self._update_histograms(metric)

    def _update_counters(self, metric: PredictionMetric):
        """Increment ``total_predictions`` and the ``model_<version>`` counter."""
        for key in ("total_predictions", f"model_{metric.model_version}"):
            self.counters[key] = self.counters.get(key, 0) + 1

    def _update_histograms(self, metric: PredictionMetric):
        """Append the metric's latency to the raw ``latency`` sample list."""
        self.histograms.setdefault("latency", []).append(metric.latency_ms)

    @staticmethod
    def _percentile(sorted_values: List[float], percentile: float) -> float:
        """Nearest-rank percentile of an ascending, non-empty list.

        The 1-based rank is ``ceil(percentile * n / 100)`` clamped to
        ``[1, n]``; e.g. p99 of 100 values is the 99th smallest, not the max.
        """
        import math

        n = len(sorted_values)
        rank = math.ceil(percentile * n / 100)
        return sorted_values[min(max(rank, 1), n) - 1]

    def get_metrics_summary(self) -> Dict:
        """Summarize latency and throughput over all recorded predictions.

        Returns an empty dict when nothing has been recorded. Throughput is
        ``count / max(1 s, span)``, where span is the time between the
        earliest and latest recorded timestamps.
        """
        if not self.predictions:
            return {}

        latencies = [p.latency_ms for p in self.predictions]
        ordered = sorted(latencies)
        timestamps = [p.timestamp for p in self.predictions]
        # total_seconds() keeps days and sub-second parts; timedelta.seconds
        # drops the days component and wraps for negative spans.
        span_seconds = (max(timestamps) - min(timestamps)).total_seconds()
        return {
            "total_predictions": len(self.predictions),
            "avg_latency": sum(latencies) / len(latencies),
            "p50_latency": self._percentile(ordered, 50),
            "p99_latency": self._percentile(ordered, 99),
            "throughput_rps": len(self.predictions) / max(1.0, span_seconds),
        }

    def export_prometheus_metrics(self) -> str:
        """Render metrics in the Prometheus text exposition format.

        Latency percentiles are exported as a ``summary`` (``quantile`` label
        plus ``_sum`` and ``_count`` series), the Prometheus type for
        precomputed quantiles. A ``histogram`` TYPE would instead require
        ``_bucket{le=...}`` series.
        """
        lines = [
            "# HELP ml_predictions_total Total predictions",
            "# TYPE ml_predictions_total counter",
            f"ml_predictions_total {self.counters.get('total_predictions', 0)}",
        ]

        latencies = [p.latency_ms for p in self.predictions]
        if latencies:
            ordered = sorted(latencies)
            lines.append("# HELP ml_prediction_latency_ms Prediction latency")
            lines.append("# TYPE ml_prediction_latency_ms summary")
            for percentile in (50, 90, 95, 99):
                value = self._percentile(ordered, percentile)
                lines.append(f'ml_prediction_latency_ms{{quantile="{percentile / 100}"}} {value}')
            lines.append(f"ml_prediction_latency_ms_sum {sum(latencies)}")
            lines.append(f"ml_prediction_latency_ms_count {len(latencies)}")

        return "\n".join(lines)

Follow-Up Questions

  1. How do you distinguish between data drift and concept drift?
  2. What thresholds should trigger automated retraining?
  3. How would you implement monitoring for streaming ML models?
  4. What are the trade-offs between different drift detection algorithms?

Advertisement