🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

Drift Detection

AIOps CoreDrift Analysis🟒 Free Lesson

Advertisement

Drift Detection

Drift detection identifies changes in data distribution (data drift) or the relationship between inputs and outputs (concept drift) that can degrade model performance over time.

Types of Drift

  • Data Drift: Changes in input feature distributions
  • Concept Drift: Changes in the mapping between inputs and outputs
  • Label Drift: Changes in the distribution of target variables
  • Upstream Drift: Changes in data sources or collection methods

Drift Detection Architecture

Population Stability Index (PSI)

Implementation

import numpy as np
import pandas as pd
from typing import List, Tuple

class PSICalculator:
    """Compute the Population Stability Index (PSI) between two samples.

    Bin edges are the quantiles of the *reference* sample. The first and
    last bins are open-ended, so current values that fall outside the
    reference range are still counted.
    """

    def __init__(self, n_bins=10, clip_min=0.001,
                 moderate_threshold=0.1, significant_threshold=0.2):
        """
        Args:
            n_bins: Number of quantile bins built from the reference data
                (fewer survive when the reference has tied quantiles).
            clip_min: Floor applied to each bin proportion so that empty
                bins do not produce ``log(0)`` or a division by zero.
            moderate_threshold: PSI at or above which drift is "moderate"
                and ``drift_detected`` becomes True.
            significant_threshold: PSI at or above which drift is
                "significant".
        """
        self.n_bins = n_bins
        self.clip_min = clip_min
        self.moderate_threshold = moderate_threshold
        self.significant_threshold = significant_threshold

    def calculate_psi(self, reference: np.ndarray, current: np.ndarray) -> float:
        """Calculate the Population Stability Index.

        PSI = sum_i (Q_i - P_i) * ln(Q_i / P_i), where P_i and Q_i are the
        reference and current proportions in bin i.

        NaN values are ignored. If every reference quantile is identical
        (a constant feature), no bins remain and the result is 0.0.

        Raises:
            ValueError: if either sample has no non-NaN values.
        """
        reference = np.asarray(reference, dtype=float)
        current = np.asarray(current, dtype=float)
        reference = reference[~np.isnan(reference)]
        current = current[~np.isnan(current)]
        if reference.size == 0 or current.size == 0:
            raise ValueError(
                "calculate_psi requires non-empty reference and current samples"
            )

        # Quantile edges of the reference; np.unique collapses duplicate
        # edges produced by discrete or heavily tied features.
        edges = np.percentile(reference, np.linspace(0, 100, self.n_bins + 1))
        bins = np.unique(edges)

        # Clip so empty bins cannot cause log(0) / division by zero.
        ref_proportions = np.clip(
            self._calculate_proportions(reference, bins), self.clip_min, None
        )
        curr_proportions = np.clip(
            self._calculate_proportions(current, bins), self.clip_min, None
        )

        psi = np.sum((curr_proportions - ref_proportions)
                     * np.log(curr_proportions / ref_proportions))
        return float(psi)

    def _calculate_proportions(self, data: np.ndarray, bins: np.ndarray) -> np.ndarray:
        """Return the fraction of ``data`` that falls in each bin.

        Bins are right-closed, ``(bins[i], bins[i + 1]]``, except that the
        first bin also absorbs everything at or below ``bins[1]`` and the
        last bin absorbs everything above ``bins[-2]``. With a single bin
        (two edges) every value lands in it. Expects ``data`` without NaNs
        (``calculate_psi`` removes them).
        """
        n_bins = len(bins) - 1
        if n_bins < 1 or len(data) == 0:
            return np.zeros(max(n_bins, 0))
        # For right-closed bins, the number of interior edges strictly below
        # a value is exactly its bin index; searchsorted(side="left") counts that.
        bin_idx = np.searchsorted(bins[1:-1], data, side="left")
        counts = np.bincount(bin_idx, minlength=n_bins)
        return counts / len(data)

    def calculate_psi_per_feature(self, reference_df: pd.DataFrame,
                                  current_df: pd.DataFrame) -> dict:
        """Calculate PSI for every numeric reference column present in ``current_df``.

        Columns with no non-NaN values on either side are skipped.

        Returns:
            ``{column: {"psi": float, "drift_detected": bool, "severity": str}}``
        """
        psi_results = {}

        for column in reference_df.select_dtypes(include=[np.number]).columns:
            if column not in current_df.columns:
                continue
            ref_data = reference_df[column].dropna().values
            curr_data = current_df[column].dropna().values
            if len(ref_data) == 0 or len(curr_data) == 0:
                continue

            psi = self.calculate_psi(ref_data, curr_data)
            severity = self._get_drift_severity(psi)
            psi_results[column] = {
                "psi": psi,
                # Derived from severity so the two can never disagree at a boundary.
                "drift_detected": severity != "none",
                "severity": severity,
            }

        return psi_results

    def _get_drift_severity(self, psi: float) -> str:
        """Map a PSI value to "none", "moderate" or "significant"."""
        if psi < self.moderate_threshold:
            return "none"
        if psi < self.significant_threshold:
            return "moderate"
        return "significant"

PSI Interpretation

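| PSI Value | Interpretation | Action Required |
|-----------|----------------|-----------------|
| PSI < 0.10 | No significant drift | Monitor |
| 0.10 ≤ PSI < 0.20 | Moderate drift | Investigate |
| PSI ≥ 0.20 | Significant drift | Retrain model |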

Kolmogorov-Smirnov Test

Implementation

from scipy import stats
import numpy as np
from typing import Dict, List

class KSTestDetector:
    """Two-sample Kolmogorov-Smirnov drift detector with simple moment diagnostics."""

    def __init__(self, significance_level=0.05):
        # A KS p-value strictly below this level is reported as drift.
        self.significance_level = significance_level

    def ks_test(self, reference: np.ndarray, current: np.ndarray) -> Dict:
        """Run scipy's two-sample KS test and attach an effect size.

        Returns a dict with keys "statistic", "p_value", "drift_detected"
        and "effect_size".
        """
        ks_stat, p_val = stats.ks_2samp(reference, current)
        is_drift = p_val < self.significance_level
        cohen_d = self._calculate_effect_size(reference, current)
        return {
            "statistic": ks_stat,
            "p_value": p_val,
            "drift_detected": is_drift,
            "effect_size": cohen_d,
        }

    def _calculate_effect_size(self, reference: np.ndarray, current: np.ndarray) -> float:
        """Absolute mean difference over the root-mean of the two population variances.

        Returns 0.0 when both samples have zero spread.
        """
        ref_var = np.std(reference)**2
        cur_var = np.std(current)**2
        pooled = np.sqrt((ref_var + cur_var) / 2)
        if pooled == 0:
            return 0.0
        mean_gap = np.mean(reference) - np.mean(current)
        return abs(mean_gap) / pooled

    def multi_ks_test(self, reference_df: pd.DataFrame,
                      current_df: pd.DataFrame) -> Dict[str, Dict]:
        """Apply ``ks_test`` to every numeric reference column also present in ``current_df``.

        NaNs are dropped first; columns left empty on either side are skipped.
        """
        numeric_cols = reference_df.select_dtypes(include=[np.number]).columns
        shared_cols = [col for col in numeric_cols if col in current_df.columns]

        results = {}
        for col in shared_cols:
            ref_vals = reference_df[col].dropna().values
            cur_vals = current_df[col].dropna().values
            if len(ref_vals) == 0 or len(cur_vals) == 0:
                continue
            results[col] = self.ks_test(ref_vals, cur_vals)
        return results

    def detect_drift_pattern(self, reference: np.ndarray,
                             current: np.ndarray) -> Dict:
        """Label the kind of change between two samples using mean, std and skew.

        Patterns: "mean_shift" (mean moved by more than 2 reference stds),
        "variance_change" (std changed by more than 50% of the reference std),
        "skewness_change" (skew changed by more than 0.5).
        """
        ref_mean = np.mean(reference)
        ref_std = np.std(reference)
        curr_mean = np.mean(current)
        curr_std = np.std(current)
        ref_skew = stats.skew(reference)
        curr_skew = stats.skew(current)

        checks = (
            ("mean_shift", abs(curr_mean - ref_mean) > 2 * ref_std),
            ("variance_change", abs(curr_std - ref_std) > 0.5 * ref_std),
            ("skewness_change", abs(curr_skew - ref_skew) > 0.5),
        )
        patterns = [label for label, fired in checks if fired]

        return {
            "patterns": patterns,
            "reference_stats": {"mean": ref_mean, "std": ref_std, "skew": ref_skew},
            "current_stats": {"mean": curr_mean, "std": curr_std, "skew": curr_skew},
        }

Mathematical Foundation

PSI Formula

Population Stability Index

$$PSI = \sum_{i=1}^{N} (P_i - Q_i) \cdot \ln\left(\frac{P_i}{Q_i}\right)$$

Where:

  • $P_i$ = proportion of observations in bin $i$ for the reference distribution
  • $Q_i$ = proportion of observations in bin $i$ for the current distribution
  • $N$ = number of bins

KS Statistic

Kolmogorov-Smirnov Statistic

$$D_{KS} = \sup_x \left| F_{ref}(x) - F_{curr}(x) \right|$$

Where:

  • $F_{ref}(x)$ = empirical cumulative distribution function of the reference data
  • $F_{curr}(x)$ = empirical cumulative distribution function of the current data

Jensen-Shannon Divergence

Jensen-Shannon Divergence

$$JSD(P \| Q) = \frac{1}{2} D_{KL}(P \| M) + \frac{1}{2} D_{KL}(Q \| M)$$

Where $M = \frac{1}{2}(P + Q)$ and $D_{KL}$ is the Kullback-Leibler divergence:

KL Divergence

$$D_{KL}(P \| Q) = \sum_{i} P(i) \log\left(\frac{P(i)}{Q(i)}\right)$$

Wasserstein Distance

Wasserstein Distance

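$$W_1(P, Q) = \int_{-\infty}^{\infty} \left| F_P(x) - F_Q(x) \right| \, dx$$

This closed form applies to one-dimensional distributions, where $F_P$ and $F_Q$ are the cumulative distribution functions of $P$ and $Q$.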

Advanced Drift Detection Methods

ADWIN (Adaptive Windowing)

from collections import deque
import numpy as np

class ADWINDetector:
    """Simplified ADWIN-style detector for a change in the mean of a stream.

    NOTE(review): unlike the published ADWIN algorithm, this variant tests a
    single split point (the middle of the window) rather than every cut, and
    the window grows without bound until drift is flagged. The threshold in
    ``_check_drift`` resembles a Hoeffding bound, which presumably assumes
    observations scaled to a bounded range such as [0, 1] — confirm for your
    inputs.
    """

    def __init__(self, delta=0.002):
        # Confidence parameter: smaller values widen epsilon (fewer alarms).
        self.delta = delta
        self.window = deque()
        self.total = 0      # running sum of the values currently in ``window``
        self.variance = 0   # population variance of ``window`` (needs >= 2 values)

    def update(self, value):
        """Append ``value``; return True if drift was detected on this step."""
        self.window.append(value)
        self.total += value
        # Full O(n) recomputation over the window (not incremental).
        self._recompute_variance()
        return self._check_drift()

    def _recompute_variance(self):
        """Refresh ``self.variance`` from the window; left unchanged for < 2 values."""
        n = len(self.window)
        if n > 1:
            mean = self.total / n
            self.variance = sum((x - mean)**2 for x in self.window) / n

    def _check_drift(self):
        """Compare the window's two halves; on drift, keep only the newer half."""
        n = len(self.window)
        if n < 10:
            return False

        values = list(self.window)
        mid = n // 2
        left_half, right_half = values[:mid], values[mid:]
        n_left, n_right = len(left_half), len(right_half)

        mean_left = np.mean(left_half)
        mean_right = np.mean(right_half)

        epsilon = np.sqrt((1.0 / (2 * n_left) + 1.0 / (2 * n_right))
                          * 2 * np.log(2.0 / self.delta))

        if abs(mean_left - mean_right) > epsilon:
            # Drop the older half and keep total/variance consistent with
            # the shrunken window (variance was previously left stale here).
            self.window = deque(right_half)
            self.total = sum(right_half)
            self._recompute_variance()
            return True

        return False

Page-Hinkley Test

class PageHinkleyDetector:
    """Page-Hinkley test for detecting an increase in the mean of a stream.

    The cumulative statistic is
    ``m_t = alpha * m_{t-1} + (x_t - mean_t - delta)``; drift is flagged when
    ``m_t - min(m)`` exceeds ``threshold``, after which all state is reset.
    """

    def __init__(self, threshold=50, alpha=0.9999, delta=0.005):
        """
        Args:
            threshold: Alarm level for ``sum - min_sum``.
            alpha: Forgetting factor applied to the cumulative sum each step.
            delta: Tolerated magnitude of change; only deviations above the
                running mean by more than this accumulate towards an alarm.
        """
        self.threshold = threshold
        self.alpha = alpha
        self.delta = delta
        self.sum = 0
        self.x_mean = 0
        self.min_sum = float('inf')
        self.sample_count = 0

    def update(self, value):
        """Add one observation; return True (and reset) if drift is detected."""
        self.sample_count += 1
        # Running mean of the observations seen since the last reset.
        self.x_mean += (value - self.x_mean) / self.sample_count
        # alpha is a forgetting factor, not the tolerance: subtracting it
        # (≈1) each step would hide any mean increase smaller than ~1 unit.
        self.sum = self.alpha * self.sum + (value - self.x_mean - self.delta)

        self.min_sum = min(self.min_sum, self.sum)

        if self.sum - self.min_sum > self.threshold:
            self._reset()
            return True

        return False

    def _reset(self):
        """Clear all running statistics so detection restarts from scratch."""
        self.sum = 0
        self.x_mean = 0
        self.min_sum = float('inf')
        self.sample_count = 0

Drift Visualization

import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict

class DriftVisualizer:
    """2x2 dashboard of drift plots drawn onto a single matplotlib figure.

    Panel layout: [0, 0] distribution overlay, [0, 1] PSI trend,
    [1, 0] KS test results, [1, 1] drift-score heatmap.
    """

    def __init__(self):
        # The figure is created eagerly; each plot_* method draws into one fixed axis.
        self.fig, self.axes = plt.subplots(2, 2, figsize=(12, 10))

    def plot_distribution_comparison(self, reference: np.ndarray,
                                     current: np.ndarray, feature_name: str):
        """Overlay histograms of the reference and current samples (top-left panel)."""
        ax = self.axes[0, 0]

        sns.histplot(reference, alpha=0.5, label='Reference', ax=ax, color='blue')
        sns.histplot(current, alpha=0.5, label='Current', ax=ax, color='red')

        ax.set_title(f'Distribution Comparison: {feature_name}')
        ax.legend()

    def plot_psi_trend(self, psi_values: list, feature_name: str):
        """Plot a sequence of PSI values with 0.1 / 0.2 guide lines (top-right panel)."""
        ax = self.axes[0, 1]

        ax.plot(psi_values, marker='o')
        ax.axhline(y=0.1, color='orange', linestyle='--', label='Moderate Drift')
        ax.axhline(y=0.2, color='red', linestyle='--', label='Significant Drift')

        ax.set_title(f'PSI Trend: {feature_name}')
        ax.set_xlabel('Time Window')
        ax.set_ylabel('PSI Value')
        ax.legend()

    def plot_ks_test_results(self, ks_results: Dict):
        """Bar chart of KS statistics with p-values overlaid (bottom-left panel).

        ``ks_results`` maps feature name -> dict with 'statistic' and
        'p_value' keys, e.g. the output of ``KSTestDetector.multi_ks_test``.
        """
        ax = self.axes[1, 0]

        features = list(ks_results.keys())
        statistics = [ks_results[f]['statistic'] for f in features]
        p_values = [ks_results[f]['p_value'] for f in features]

        x = range(len(features))
        ax.bar(x, statistics, alpha=0.6, label='KS Statistic')
        ax.scatter(x, p_values, color='red', label='P-value')

        ax.set_xticks(x)
        ax.set_xticklabels(features, rotation=45)
        ax.set_title('KS Test Results')
        ax.legend()

    def plot_drift_heatmap(self, drift_scores: Dict[str, Dict]):
        """Annotated heatmap of per-feature drift scores (bottom-right panel).

        Each value of ``drift_scores`` is read for the keys 'psi',
        'ks_statistic' and 'effect_size'; missing keys are shown as 0.
        """
        ax = self.axes[1, 1]

        features = list(drift_scores.keys())
        metrics = ['psi', 'ks_statistic', 'effect_size']

        data = [[drift_scores[feature].get(m, 0) for m in metrics]
                for feature in features]

        sns.heatmap(data, annot=True, xticklabels=metrics,
                    yticklabels=features, ax=ax)
        ax.set_title('Drift Scores Heatmap')

    def show(self):
        """Lay out this visualizer's figure and display it."""
        # plt.tight_layout() acts on whichever figure is *current*, which is
        # not necessarily self.fig if other figures were created since __init__.
        self.fig.tight_layout()
        plt.show()

Best Practices

1. Reference Data Management

class ReferenceDataManager:
    """Keeps the versioned reference dataset that drift is measured against."""

    def __init__(self, storage_path):
        # Destination used by the storage helpers (not shown in this snippet).
        self.storage_path = storage_path

    def update_reference(self, new_data, version):
        """Validate ``new_data`` and, if acceptable, store it under ``version``.

        Raises:
            ValueError: if ``new_data`` fails ``_validate_reference``.
        """
        if self._validate_reference(new_data):
            # NOTE(review): _store_reference / _update_metadata are not defined
            # in this snippet; presumably implemented elsewhere.
            self._store_reference(new_data, version)
            self._update_metadata(version)
            return
        raise ValueError("Invalid reference data")

    def _validate_reference(self, data):
        """Return True iff ``data`` has at least 1000 rows and no missing values.

        Presumably ``data`` is a pandas DataFrame (``isnull().any().any()``
        reduces over columns, then overall) — confirm against callers.
        """
        enough_rows = len(data) >= 1000
        return enough_rows and not data.isnull().any().any()

2. Automated Drift Response

class DriftResponseManager:
    """Turn per-feature PSI drift results into an operational response."""

    def __init__(self, model_registry, retraining_pipeline,
                 moderate_threshold=0.1, significant_threshold=0.2):
        """
        Args:
            model_registry: Stored for use by callers/helpers; not read in this class body.
            retraining_pipeline: Object exposing ``trigger_retraining()``.
            moderate_threshold: Max PSI at or above which severity is "moderate".
            significant_threshold: Max PSI at or above which severity is "significant".
        """
        self.model_registry = model_registry
        self.retraining_pipeline = retraining_pipeline
        self.moderate_threshold = moderate_threshold
        self.significant_threshold = significant_threshold

    def handle_drift(self, drift_results):
        """Act on ``drift_results`` according to the assessed severity.

        "significant" triggers retraining and an alert; "moderate" raises the
        monitoring frequency and logs the incident; "minor" does nothing.
        NOTE(review): ``_send_alert``, ``_increase_monitoring_frequency`` and
        ``_log_drift_incident`` are not defined in this snippet.
        """
        severity = self._assess_severity(drift_results)

        if severity == "significant":
            self.retraining_pipeline.trigger_retraining()
            self._send_alert("Significant drift detected", severity)

        elif severity == "moderate":
            self._increase_monitoring_frequency()
            self._log_drift_incident(drift_results)

    def _assess_severity(self, drift_results):
        """Classify by the largest per-feature PSI: "significant", "moderate" or "minor".

        ``drift_results`` maps feature -> dict with a 'psi' key (the shape
        produced by ``PSICalculator.calculate_psi_per_feature``). An empty
        mapping is "minor" instead of raising from ``max()``.
        """
        max_psi = max((r['psi'] for r in drift_results.values()), default=0.0)

        # ">=" matches PSICalculator._get_drift_severity (0.1 -> moderate, 0.2 -> significant).
        if max_psi >= self.significant_threshold:
            return "significant"
        if max_psi >= self.moderate_threshold:
            return "moderate"
        return "minor"

Summary

Drift detection is crucial for maintaining model performance in production. By implementing comprehensive drift detection using statistical tests like PSI and KS test, organizations can identify distribution changes early and take corrective actions to maintain model accuracy and reliability.

⭐

Premium Content

Drift Detection

Unlock this lesson and 900+ advanced tutorials with a Premium plan.

🎯 End-to-end Projects
💼 Interview Prep
📜 Certificates
🤝 Community Access

Already a member? Log in

Need Expert AI Ops & LLM Ops Help?

Get personalized tutoring, project support, or professional consulting.

Advertisement