Drift Detection
Drift detection identifies changes in data distribution (data drift) or the relationship between inputs and outputs (concept drift) that can degrade model performance over time.
Types of Drift
- Data Drift: Changes in input feature distributions
- Concept Drift: Changes in the mapping between inputs and outputs
- Label Drift: Changes in the distribution of target variables
- Upstream Drift: Changes in data sources or collection methods
Drift Detection Architecture
Population Stability Index (PSI)
Implementation
import numpy as np
import pandas as pd
from typing import List, Tuple
class PSICalculator:
    """Compute the Population Stability Index (PSI) between two samples.

    Both samples are bucketed with edges taken from the quantiles of the
    reference sample; PSI then sums ``(q - p) * ln(q / p)`` over the buckets,
    where ``p`` and ``q`` are the reference and current bucket proportions.

    Args:
        n_bins: Number of quantile buckets requested. Fewer buckets are used
            when the reference quantiles contain duplicate values.
        clip_min: Lower bound applied to every bucket proportion so that the
            ratio and logarithm in the PSI formula are always defined.
    """

    def __init__(self, n_bins=10, clip_min=0.001):
        self.n_bins = n_bins
        self.clip_min = clip_min

    def calculate_psi(self, reference: np.ndarray, current: np.ndarray) -> float:
        """Calculate the Population Stability Index of ``current`` vs ``reference``.

        Args:
            reference: Baseline sample; its quantiles define the bucket edges.
            current: Sample to compare against the baseline.

        Returns:
            The PSI as a built-in ``float`` (0.0 for identical bucket shares).

        Raises:
            ValueError: If either sample is empty, because proportions are
                undefined without observations.
        """
        reference = np.asarray(reference)
        current = np.asarray(current)
        if reference.size == 0 or current.size == 0:
            raise ValueError("PSI requires non-empty reference and current samples")

        bins = self._bin_edges(reference, current)

        # Clip so that empty buckets do not produce log(0) or division by zero.
        ref_proportions = np.clip(
            self._calculate_proportions(reference, bins), self.clip_min, None
        )
        curr_proportions = np.clip(
            self._calculate_proportions(current, bins), self.clip_min, None
        )

        psi = np.sum(
            (curr_proportions - ref_proportions)
            * np.log(curr_proportions / ref_proportions)
        )
        return float(psi)

    def _bin_edges(self, reference: np.ndarray, current: np.ndarray) -> np.ndarray:
        """Return sorted, de-duplicated bucket edges for the two samples.

        Edges are the reference quantiles. When those collapse to fewer than
        three distinct values (constant or almost-constant reference) they
        would describe a single bucket and PSI would always be 0, so
        equal-width edges over the pooled range of both samples are used.
        """
        quantiles = np.percentile(reference, np.linspace(0, 100, self.n_bins + 1))
        edges = np.unique(quantiles)  # duplicate edges would create empty buckets
        if edges.size >= 3:
            return edges

        lo = min(reference.min(), current.min())
        hi = max(reference.max(), current.max())
        if lo == hi:
            # Both samples hold one identical value: a single bucket is exact.
            return np.array([lo])
        return np.linspace(lo, hi, self.n_bins + 1)

    def _calculate_proportions(self, data: np.ndarray, bins: np.ndarray) -> np.ndarray:
        """Return the share of ``data`` that falls into each bucket.

        The first bucket is open to the left and the last bucket is open to
        the right, so every value is counted exactly once; interior buckets
        are ``(bins[i], bins[i + 1]]``. With fewer than three edges there is
        a single bucket that receives all of ``data``.
        """
        values = np.ravel(np.asarray(data))
        n_buckets = max(len(bins) - 1, 1)
        if values.size == 0:
            return np.zeros(n_buckets)

        # Only the interior edges separate buckets; the outer two are open.
        interior_edges = np.asarray(bins)[1:-1]
        bucket_index = np.searchsorted(interior_edges, values, side="left")
        counts = np.bincount(bucket_index, minlength=n_buckets)
        return counts / values.size

    def calculate_psi_per_feature(self, reference_df: pd.DataFrame,
                                  current_df: pd.DataFrame) -> dict:
        """Calculate PSI for every numeric reference column present in both frames.

        Missing values are dropped per column. Columns that have no remaining
        values in either frame are skipped instead of producing NaN.

        Returns:
            Mapping of column name to a dict with keys ``"psi"``,
            ``"drift_detected"`` and ``"severity"``.
        """
        psi_results = {}
        for column in reference_df.select_dtypes(include=[np.number]).columns:
            if column not in current_df.columns:
                continue
            ref_values = reference_df[column].dropna().values
            curr_values = current_df[column].dropna().values
            if len(ref_values) == 0 or len(curr_values) == 0:
                continue

            psi = self.calculate_psi(ref_values, curr_values)
            severity = self._get_drift_severity(psi)
            psi_results[column] = {
                "psi": psi,
                # Derived from severity so the flag and the label can never
                # disagree at the 0.1 boundary.
                "drift_detected": severity != "none",
                "severity": severity,
            }
        return psi_results

    def _get_drift_severity(self, psi: float) -> str:
        """Map a PSI value to ``"none"`` (< 0.1), ``"moderate"`` (< 0.2) or ``"significant"``."""
        if psi < 0.1:
            return "none"
        elif psi < 0.2:
            return "moderate"
        else:
            return "significant"
PSI Interpretation
| PSI Value | Interpretation | Action Required |
|---|---|---|
| < 0.10 | No significant drift | Monitor |
| 0.10 to < 0.20 | Moderate drift | Investigate |
| ≥ 0.20 | Significant drift | Retrain model |
Kolmogorov-Smirnov Test
Implementation
from scipy import stats
import numpy as np
from typing import Dict, List
class KSTestDetector:
    """Detect distribution drift with the two-sample Kolmogorov-Smirnov test.

    Args:
        significance_level: p-value below which drift is reported.
    """

    def __init__(self, significance_level=0.05):
        self.significance_level = significance_level

    def ks_test(self, reference: np.ndarray, current: np.ndarray) -> Dict:
        """Perform the two-sample Kolmogorov-Smirnov test.

        Returns:
            Dict with ``"statistic"``, ``"p_value"``, ``"drift_detected"`` and
            ``"effect_size"``. Values are built-in ``float``/``bool`` so the
            result can be serialized (for example with ``json``) as is.
        """
        statistic, p_value = stats.ks_2samp(reference, current)
        return {
            "statistic": float(statistic),
            "p_value": float(p_value),
            # bool(): the raw comparison is a numpy bool, which json rejects.
            "drift_detected": bool(p_value < self.significance_level),
            "effect_size": self._calculate_effect_size(reference, current),
        }

    def _calculate_effect_size(self, reference: np.ndarray, current: np.ndarray) -> float:
        """Return the absolute standardized mean difference (Cohen's d style).

        The pooled standard deviation is the square root of the plain average
        of the two population variances (``ddof=0``). Returns 0.0 when both
        samples have zero variance, to avoid dividing by zero.
        """
        pooled_std = np.sqrt((np.std(reference) ** 2 + np.std(current) ** 2) / 2)
        if pooled_std == 0:
            return 0.0
        return float(abs(np.mean(reference) - np.mean(current)) / pooled_std)

    def multi_ks_test(self, reference_df: pd.DataFrame,
                      current_df: pd.DataFrame) -> Dict[str, Dict]:
        """Run :meth:`ks_test` on every numeric reference column found in both frames.

        Missing values are dropped per column; columns left empty in either
        frame are skipped.
        """
        results = {}
        for column in reference_df.select_dtypes(include=[np.number]).columns:
            if column not in current_df.columns:
                continue
            ref_data = reference_df[column].dropna().values
            curr_data = current_df[column].dropna().values
            if len(ref_data) > 0 and len(curr_data) > 0:
                results[column] = self.ks_test(ref_data, curr_data)
        return results

    def detect_drift_pattern(self, reference: np.ndarray,
                             current: np.ndarray, *,
                             mean_shift_threshold: float = 2.0,
                             variance_change_threshold: float = 0.5,
                             skew_change_threshold: float = 0.5) -> Dict:
        """Label the kind of drift by comparing mean, spread and skewness.

        Args:
            reference: Baseline sample.
            current: Sample to compare against the baseline.
            mean_shift_threshold: ``"mean_shift"`` is reported when the means
                differ by more than this many reference standard deviations.
            variance_change_threshold: ``"variance_change"`` is reported when
                the standard deviations differ by more than this fraction of
                the reference standard deviation.
            skew_change_threshold: ``"skewness_change"`` is reported when the
                skewness differs by more than this absolute amount.

        Returns:
            Dict with ``"patterns"`` (list of labels) plus ``"reference_stats"``
            and ``"current_stats"`` (each with ``mean``, ``std``, ``skew``).
        """
        ref_mean, ref_std = float(np.mean(reference)), float(np.std(reference))
        curr_mean, curr_std = float(np.mean(current)), float(np.std(current))
        ref_skew = float(stats.skew(reference))
        curr_skew = float(stats.skew(current))

        patterns = []
        if abs(curr_mean - ref_mean) > mean_shift_threshold * ref_std:
            patterns.append("mean_shift")
        if abs(curr_std - ref_std) > variance_change_threshold * ref_std:
            patterns.append("variance_change")
        if abs(curr_skew - ref_skew) > skew_change_threshold:
            patterns.append("skewness_change")

        return {
            "patterns": patterns,
            "reference_stats": {"mean": ref_mean, "std": ref_std, "skew": ref_skew},
            "current_stats": {"mean": curr_mean, "std": curr_std, "skew": curr_skew},
        }
Mathematical Foundation
PSI Formula
Population Stability Index
\[ \mathrm{PSI} = \sum_{i=1}^{N} (Q_i - P_i) \, \ln\left(\frac{Q_i}{P_i}\right) \]
Where:
- \( P_i \) = proportion of observations in bin \( i \) for the reference distribution
- \( Q_i \) = proportion of observations in bin \( i \) for the current distribution
- \( N \) = number of bins
KS Statistic
Kolmogorov-Smirnov Statistic
\[ D = \sup_{x} \left| F_{ref}(x) - F_{curr}(x) \right| \]
Where:
- \( F_{ref}(x) \) = empirical cumulative distribution function of the reference data
- \( F_{curr}(x) \) = empirical cumulative distribution function of the current data
Jensen-Shannon Divergence
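Jensen-Shannon Divergence
\[ \mathrm{JSD}(P \parallel Q) = \frac{1}{2} D_{KL}(P \parallel M) + \frac{1}{2} D_{KL}(Q \parallel M) \]
Where \( M = \frac{1}{2}(P + Q) \) and \( D_{KL} \) is the Kullback-Leibler divergence:
KL Divergence
\[ D_{KL}(P \parallel Q) = \sum_{i} P_i \, \ln\left(\frac{P_i}{Q_i}\right) \]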
Wasserstein Distance
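Wasserstein Distance
\[ W_1(P, Q) = \int_{-\infty}^{\infty} \left| F_{ref}(x) - F_{curr}(x) \right| \, dx \]
For a one-dimensional feature, this is the area between the cumulative distribution functions of the reference and current data.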
Advanced Drift Detection Methods
ADWIN (Adaptive Windowing)
from collections import deque
import numpy as np
class ADWINDetector:
    """Simplified ADWIN-style drift detector for a stream of numbers.

    Every observation is appended to a growing window. Once the window holds
    at least 10 values it is split at its midpoint and the means of the two
    halves are compared against a bound derived from ``delta``. When the
    means differ by more than the bound, the older half is discarded.

    NOTE: only the midpoint split is tested, and the window is unbounded
    while no drift is found, so each update costs time linear in the window.

    Args:
        delta: Confidence parameter of the bound; smaller values make the
            detector less sensitive.

    Attributes:
        window: ``deque`` of the observations currently retained.
        total: Sum of the values in ``window``.
        variance: Population variance of the values in ``window``.
    """

    def __init__(self, delta=0.002):
        self.delta = delta
        self.window = deque()
        self.total = 0
        self.variance = 0

    def update(self, value):
        """Add ``value`` to the window and return True if drift was detected."""
        self.window.append(value)
        self.total += value

        drift_detected = self._check_drift()

        # Refresh after the check so the statistic describes the window that
        # is actually kept, including when drift has just shrunk it.
        self._refresh_variance()
        return drift_detected

    def _refresh_variance(self):
        """Recompute ``self.variance`` from the current window contents."""
        if len(self.window) > 1:
            self.variance = float(np.var(np.fromiter(self.window, dtype=float)))
        else:
            self.variance = 0

    def _check_drift(self):
        """Compare the two window halves; shrink the window on drift.

        Returns:
            True when the absolute difference of the half means exceeds
            ``sqrt((1/(2*n_left) + 1/(2*n_right)) * 2 * ln(2/delta))``;
            False otherwise, or when fewer than 10 values are available.
        """
        n = len(self.window)
        if n < 10:
            return False

        values = list(self.window)
        mid = n // 2
        left_half, right_half = values[:mid], values[mid:]
        n_left, n_right = len(left_half), len(right_half)

        epsilon = np.sqrt((1.0 / (2 * n_left) + 1.0 / (2 * n_right)) *
                          2 * np.log(2.0 / self.delta))

        if abs(np.mean(left_half) - np.mean(right_half)) > epsilon:
            # Drift detected: keep only the newer half as the new window.
            self.window = deque(right_half)
            self.total = sum(right_half)
            return True
        return False
Page-Hinkley Test
class PageHinkleyDetector:
    """Page-Hinkley test for detecting an increase in the mean of a stream.

    For each observation the deviation from the running mean, minus the
    tolerance ``delta``, is added to a cumulative sum that is first decayed
    by ``alpha``. Drift is signalled when the sum rises more than
    ``threshold`` above the smallest value it has reached.

    Args:
        threshold: Detection threshold on ``sum - min_sum``.
        alpha: Forgetting factor applied to the cumulative sum at each update
            (1.0 disables forgetting).
        delta: Magnitude of change tolerated per observation before it
            contributes to the cumulative sum.
    """

    def __init__(self, threshold=50, alpha=0.9999, delta=0.005):
        self.threshold = threshold
        self.alpha = alpha
        self.delta = delta
        self.sum = 0
        self.x_mean = 0
        self.min_sum = float('inf')
        self.sample_count = 0

    def update(self, value):
        """Consume one observation and return True if drift was detected.

        The detector resets itself after a detection, so monitoring resumes
        from a clean state with the next observation.
        """
        self.sample_count += 1
        self.x_mean += (value - self.x_mean) / self.sample_count

        # alpha decays the accumulated evidence; delta is the per-sample
        # tolerance subtracted from the deviation.
        self.sum = self.alpha * self.sum + (value - self.x_mean - self.delta)
        self.min_sum = min(self.min_sum, self.sum)

        if self.sum - self.min_sum > self.threshold:
            self._reset()
            return True
        return False

    def _reset(self):
        """Clear all running statistics."""
        self.sum = 0
        self.x_mean = 0
        self.min_sum = float('inf')
        self.sample_count = 0
Drift Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict
class DriftVisualizer:
    """Four-panel matplotlib dashboard for drift diagnostics.

    Panel layout on the 2x2 grid created at construction time:
    top-left distribution comparison, top-right PSI trend,
    bottom-left KS test results, bottom-right drift score heatmap.
    """

    def __init__(self):
        self.fig, self.axes = plt.subplots(2, 2, figsize=(12, 10))

    def plot_distribution_comparison(self, reference: np.ndarray,
                                     current: np.ndarray, feature_name: str):
        """Overlay reference (blue) and current (red) histograms, top-left."""
        panel = self.axes[0, 0]
        series = (
            (reference, 'Reference', 'blue'),
            (current, 'Current', 'red'),
        )
        for sample, legend_label, colour in series:
            sns.histplot(sample, alpha=0.5, label=legend_label, ax=panel, color=colour)
        panel.set_title(f'Distribution Comparison: {feature_name}')
        panel.legend()

    def plot_psi_trend(self, psi_values: list, feature_name: str):
        """Plot PSI per time window with the 0.1 and 0.2 guide lines, top-right."""
        panel = self.axes[0, 1]
        panel.plot(psi_values, marker='o')
        guide_lines = (
            (0.1, 'orange', 'Moderate Drift'),
            (0.2, 'red', 'Significant Drift'),
        )
        for level, colour, legend_label in guide_lines:
            panel.axhline(y=level, color=colour, linestyle='--', label=legend_label)
        panel.set_title(f'PSI Trend: {feature_name}')
        panel.set_xlabel('Time Window')
        panel.set_ylabel('PSI Value')
        panel.legend()

    def plot_ks_test_results(self, ks_results: Dict):
        """Draw KS statistics as bars and p-values as red points, bottom-left.

        ``ks_results`` maps feature name to a dict that provides the keys
        ``'statistic'`` and ``'p_value'``.
        """
        panel = self.axes[1, 0]
        features = list(ks_results.keys())
        positions = range(len(features))
        statistics = [ks_results[name]['statistic'] for name in features]
        p_values = [ks_results[name]['p_value'] for name in features]

        panel.bar(positions, statistics, alpha=0.6, label='KS Statistic')
        panel.scatter(positions, p_values, color='red', label='P-value')
        panel.set_xticks(positions)
        panel.set_xticklabels(features, rotation=45)
        panel.set_title('KS Test Results')
        panel.legend()

    def plot_drift_heatmap(self, drift_scores: Dict[str, Dict]):
        """Render a feature-by-metric heatmap, bottom-right.

        The columns are ``'psi'``, ``'ks_statistic'`` and ``'effect_size'``;
        a metric missing for a feature is shown as 0.
        """
        panel = self.axes[1, 1]
        metrics = ['psi', 'ks_statistic', 'effect_size']
        features = list(drift_scores.keys())
        data = [
            [drift_scores[name].get(metric, 0) for metric in metrics]
            for name in features
        ]
        sns.heatmap(data, annot=True, xticklabels=metrics,
                    yticklabels=features, ax=panel)
        panel.set_title('Drift Scores Heatmap')

    def show(self):
        """Tighten the layout and display the figure."""
        plt.tight_layout()
        plt.show()
Best Practices
1. Reference Data Management
class ReferenceDataManager:
    """Gatekeeper for versioned reference datasets stored under ``storage_path``."""

    def __init__(self, storage_path):
        self.storage_path = storage_path

    def update_reference(self, new_data, version):
        """Validate ``new_data`` and, if it passes, persist it as ``version``.

        Raises:
            ValueError: If ``new_data`` fails :meth:`_validate_reference`.
        """
        accepted = self._validate_reference(new_data)
        if not accepted:
            raise ValueError("Invalid reference data")

        # NOTE(review): _store_reference and _update_metadata are not defined
        # in the visible part of this class - presumably provided elsewhere.
        self._store_reference(new_data, version)
        self._update_metadata(version)

    def _validate_reference(self, data):
        """Return True only for data with at least 1000 rows and no missing values."""
        enough_samples = len(data) >= 1000
        if enough_samples and not data.isnull().any().any():
            return True
        return False
2. Automated Drift Response
class DriftResponseManager:
    """Route drift results to retraining, alerting or closer monitoring.

    Args:
        model_registry: Stored on the instance; not used by the methods
            visible here.
        retraining_pipeline: Object whose ``trigger_retraining()`` is called
            on significant drift.
    """

    def __init__(self, model_registry, retraining_pipeline):
        self.model_registry = model_registry
        self.retraining_pipeline = retraining_pipeline

    def handle_drift(self, drift_results):
        """React to ``drift_results`` according to the assessed severity.

        Significant drift triggers retraining and an alert; moderate drift
        increases the monitoring frequency and logs the incident; anything
        else is ignored.

        NOTE(review): _send_alert, _increase_monitoring_frequency and
        _log_drift_incident are not defined in the visible part of this
        class - presumably provided elsewhere.
        """
        severity = self._assess_severity(drift_results)
        if severity == "significant":
            self.retraining_pipeline.trigger_retraining()
            self._send_alert("Significant drift detected", severity)
        elif severity == "moderate":
            self._increase_monitoring_frequency()
            self._log_drift_incident(drift_results)

    def _assess_severity(self, drift_results):
        """Classify the largest per-feature PSI in ``drift_results``.

        Args:
            drift_results: Mapping of feature name to a dict with a ``'psi'``
                entry.

        Returns:
            ``"significant"`` for a maximum PSI of at least 0.2,
            ``"moderate"`` for at least 0.1, otherwise ``"minor"``. An empty
            mapping is treated as ``"minor"``.
        """
        # default=0.0: no features analysed means no evidence of drift,
        # rather than a crash on max() of an empty sequence.
        max_psi = max((r['psi'] for r in drift_results.values()), default=0.0)
        if max_psi >= 0.2:
            return "significant"
        if max_psi >= 0.1:
            return "moderate"
        return "minor"
Summary
Drift detection is crucial for maintaining model performance in production. By combining statistical measures such as PSI with hypothesis tests such as the Kolmogorov-Smirnov test, organizations can identify distribution changes early and take corrective action to maintain model accuracy and reliability.