NLP Monitoring and Observability
Production NLP systems require continuous monitoring to detect performance degradation, data drift, and quality issues before they impact users.
Monitoring Dimensions
| Category | Metrics | Alert Threshold |
|---|---|---|
| Performance | Latency P50/P95/P99 | P99 > 500ms |
| Throughput | Requests per second | RPS < expected |
| Quality | Accuracy, F1, perplexity | Drop > 2% |
| Drift | Data distribution shift | KL > 0.1 |
| Errors | Error rate, timeout rate | Error > 1% |
| Resources | CPU, GPU, memory usage | Usage > 90% |
Data Drift Detection
Data drift occurs when the distribution of incoming data differs from training data.
Drift Metrics
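Population Stability Index (PSI): $\mathrm{PSI} = \sum_{i=1}^{B} (q_i - p_i)\,\ln\frac{q_i}{p_i}$, where $p_i$ and $q_i$ are the fractions of reference and current samples that fall in bin $i$.
Kolmogorov-Smirnov Statistic: $D = \sup_x \lvert F_{\mathrm{ref}}(x) - F_{\mathrm{cur}}(x) \rvert$, the largest gap between the empirical CDFs of the reference and current samples.
Wasserstein Distance: $W_1 = \int_{-\infty}^{\infty} \lvert F_{\mathrm{ref}}(x) - F_{\mathrm{cur}}(x) \rvert \, dx$, the area between the two empirical CDFs.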
import numpy as np
from scipy import stats
from collections import Counter
class DriftDetector:
    """Detect distribution drift between a reference sample and new data.

    Numeric drift is measured against ``reference_data`` with the Population
    Stability Index, the two-sample Kolmogorov-Smirnov test and the
    Wasserstein distance. Text drift is measured with the Jensen-Shannon
    divergence between whitespace-token frequency distributions.

    Attributes:
        reference_data: The baseline sample passed to the constructor.
        bins: Default histogram bin specification used for PSI.
        reference_hist: Density histogram of the reference data.
        bin_edges: Bin edges derived from the reference data.
    """

    # Floor applied to bin proportions so that an empty bin never produces
    # log(0) or a division by zero.
    _EPSILON = 1e-10

    def __init__(self, reference_data, bins=10):
        """Store the reference sample and pre-compute its histogram.

        Args:
            reference_data: Baseline numeric sample.
            bins: Bin count or bin edges, forwarded to ``np.histogram``.
        """
        self.reference_data = reference_data
        self.bins = bins
        self.reference_hist, self.bin_edges = np.histogram(
            reference_data, bins=bins, density=True
        )

    @staticmethod
    def _bin_proportions(data, edges):
        """Return the fraction of ``data`` that falls in each bin of ``edges``."""
        values = np.asarray(data, dtype=float).ravel()
        if values.size == 0:
            # No observations: the proportions are undefined.
            return np.full(len(edges) - 1, np.nan)
        # np.histogram drops values outside the edges. Clamp them into the
        # outermost bins instead, because out-of-range values are exactly the
        # shift PSI is supposed to expose.
        clamped = np.clip(values, edges[0], edges[-1])
        counts, _ = np.histogram(clamped, bins=edges)
        return counts / values.size

    @staticmethod
    def _kl_divergence(p, q):
        """Return KL(p || q) in nats, skipping terms where ``p`` is zero."""
        support = p > 0
        return float(np.sum(p[support] * np.log(p[support] / q[support])))

    def compute_psi(self, current_data, bins=None):
        """Compute the Population Stability Index of ``current_data``.

        PSI is computed from per-bin *proportions* (not densities), so the
        result does not depend on the bin width.

        Args:
            current_data: Numeric sample to compare with the reference data.
            bins: Optional bin count or edges. When given, the bin edges are
                re-derived from the reference data with this specification;
                otherwise the edges computed in the constructor are used.

        Returns:
            The PSI as a float, or NaN when ``current_data`` is empty.
        """
        if bins is None:
            edges = self.bin_edges
        else:
            _, edges = np.histogram(self.reference_data, bins=bins)

        current = self._bin_proportions(current_data, edges)
        reference = self._bin_proportions(self.reference_data, edges)

        # Avoid division by zero / log(0) for empty bins.
        current = np.clip(current, self._EPSILON, None)
        reference = np.clip(reference, self._EPSILON, None)
        return float(np.sum((current - reference) * np.log(current / reference)))

    def compute_ks_test(self, current_data, alpha=0.05):
        """Run a two-sample Kolmogorov-Smirnov test against the reference data.

        Args:
            current_data: Numeric sample to compare with the reference data.
            alpha: Significance level below which drift is reported.

        Returns:
            A dict with ``statistic``, ``p_value`` and ``is_drifted``.
        """
        statistic, p_value = stats.ks_2samp(self.reference_data, current_data)
        return {
            "statistic": float(statistic),
            "p_value": float(p_value),
            "is_drifted": bool(p_value < alpha),
        }

    def compute_wasserstein(self, current_data):
        """Return the Wasserstein distance between reference and current data."""
        return float(stats.wasserstein_distance(self.reference_data, current_data))

    def detect_text_drift(self, reference_texts, current_texts, threshold=0.1):
        """Detect drift in text data using the token distribution.

        Texts are lower-cased and split on whitespace; the two token
        frequency distributions are compared with the Jensen-Shannon
        divergence (natural log, so the value lies in [0, ln 2]).

        Args:
            reference_texts: Iterable of baseline strings.
            current_texts: Iterable of strings to compare.
            threshold: Divergence above which drift is reported.

        Returns:
            A dict with ``js_divergence`` and ``is_drifted``. When either side
            contains no tokens the divergence is undefined: ``js_divergence``
            is NaN and ``is_drifted`` is False.
        """
        ref_counts = Counter(
            token for text in reference_texts for token in text.lower().split()
        )
        cur_counts = Counter(
            token for text in current_texts for token in text.lower().split()
        )
        ref_total = sum(ref_counts.values())
        cur_total = sum(cur_counts.values())
        if ref_total == 0 or cur_total == 0:
            return {"js_divergence": float("nan"), "is_drifted": False}

        # One fixed ordering of the joint vocabulary for both vectors.
        vocabulary = list(ref_counts.keys() | cur_counts.keys())
        ref_probs = np.array([ref_counts[t] / ref_total for t in vocabulary])
        cur_probs = np.array([cur_counts[t] / cur_total for t in vocabulary])

        # The mixture is positive wherever either distribution is positive,
        # so the KL terms below never divide by zero.
        mixture = 0.5 * (ref_probs + cur_probs)
        js_divergence = 0.5 * (
            self._kl_divergence(ref_probs, mixture)
            + self._kl_divergence(cur_probs, mixture)
        )
        # Guard against a tiny negative value from floating-point rounding.
        js_divergence = max(js_divergence, 0.0)
        return {
            "js_divergence": js_divergence,
            "is_drifted": js_divergence > threshold,
        }
# Usage: compare inference-time data against the training baseline.
# NOTE(review): train_embeddings and inference_embeddings are not defined in
# this snippet; they are assumed to be provided by the surrounding code.
detector = DriftDetector(reference_data=train_embeddings)
psi = detector.compute_psi(current_data=inference_embeddings)
ks = detector.compute_ks_test(current_data=inference_embeddings)

psi_status = "Drift detected" if psi > 0.1 else "No drift"
print(f"PSI: {psi:.4f} ({psi_status})")

ks_status = "Drift detected" if ks["is_drifted"] else "No drift"
print(f"KS test: p={ks['p_value']:.4f} ({ks_status})")
Latency Monitoring
import time
from dataclasses import dataclass
from typing import List
import statistics
@dataclass
class LatencyMetrics:
    """Summary statistics for a set of latency samples."""

    # 50th, 95th and 99th percentile of the samples.
    p50: float
    p95: float
    p99: float
    # Arithmetic mean and standard deviation of the samples.
    mean: float
    std: float
    # Number of samples the statistics were computed from.
    count: int
class LatencyMonitor:
    """Keep a sliding window of latency samples and summarise them.

    Attributes:
        window_size: Maximum number of samples retained.
        latencies: Retained samples in milliseconds, oldest first.
    """

    def __init__(self, window_size=1000):
        self.window_size = window_size
        self.latencies: List[float] = []

    def record(self, latency_ms: float):
        """Append a sample, evicting the oldest once the window is full."""
        self.latencies.append(latency_ms)
        if len(self.latencies) > self.window_size:
            self.latencies.pop(0)

    def get_metrics(self) -> "Optional[LatencyMetrics]":
        """Summarise the current window.

        Returns:
            A ``LatencyMetrics`` instance, or ``None`` when no sample has been
            recorded yet. ``std`` is 0.0 when only one sample is available.
        """
        if not self.latencies:
            return None
        p50, p95, p99 = np.percentile(self.latencies, [50, 95, 99])
        has_spread = len(self.latencies) > 1
        return LatencyMetrics(
            p50=float(p50),
            p95=float(p95),
            p99=float(p99),
            mean=statistics.mean(self.latencies),
            std=statistics.stdev(self.latencies) if has_spread else 0.0,
            count=len(self.latencies),
        )

    def detect_anomalies(self, threshold_sigma=3, recent_count=50, min_samples=30):
        """Detect latency spikes among the most recent samples using z-scores.

        The mean and standard deviation are taken over the whole window; only
        the last ``recent_count`` samples are tested against them.

        Args:
            threshold_sigma: Number of standard deviations from the mean
                beyond which a sample counts as an anomaly.
            recent_count: How many of the newest samples to inspect.
            min_samples: Minimum window population required before any
                anomaly is reported.

        Returns:
            A list of ``(index, latency)`` tuples, where ``index`` is the
            sample's position in ``self.latencies``. Empty when fewer than
            ``min_samples`` samples are available.
        """
        total = len(self.latencies)
        # stdev needs at least two samples regardless of min_samples.
        if total < max(min_samples, 2):
            return []
        mean = statistics.mean(self.latencies)
        limit = threshold_sigma * statistics.stdev(self.latencies)
        # Offset the enumeration so that reported indices point into
        # self.latencies rather than into the inspected tail slice.
        first = max(total - recent_count, 0)
        return [
            (index, latency)
            for index, latency in enumerate(self.latencies[first:], start=first)
            if abs(latency - mean) > limit
        ]
# NLP-specific latency tracking
class NLPLatencyTracker:
    """Track per-stage and end-to-end latency of an NLP request.

    One ``LatencyMonitor`` is kept for each of tokenization, inference and
    postprocessing, plus one for the total request time.
    """

    def __init__(self):
        self.tokenization_latency = LatencyMonitor()
        self.inference_latency = LatencyMonitor()
        self.postprocessing_latency = LatencyMonitor()
        self.total_latency = LatencyMonitor()

    @staticmethod
    def _elapsed_ms(started_at):
        """Return the milliseconds elapsed since ``started_at``."""
        return (time.perf_counter() - started_at) * 1000

    def track_request(self, request_fn, input_text):
        """Run one request through the pipeline and record stage latencies.

        Durations are measured with ``time.perf_counter`` rather than
        ``time.time``: the wall clock can jump (NTP, manual adjustment), which
        would produce negative or inflated latencies.

        NOTE(review): ``tokenize`` and ``postprocess`` are module-level names
        that are not defined in the visible source; they are assumed to be
        provided elsewhere.

        Args:
            request_fn: Callable invoked with the tokenized input.
            input_text: Raw text passed to ``tokenize``.

        Returns:
            Whatever ``postprocess`` returns for the output of ``request_fn``.
        """
        request_start = time.perf_counter()

        # Tokenization
        stage_start = time.perf_counter()
        tokens = tokenize(input_text)
        self.tokenization_latency.record(self._elapsed_ms(stage_start))

        # Inference
        stage_start = time.perf_counter()
        output = request_fn(tokens)
        self.inference_latency.record(self._elapsed_ms(stage_start))

        # Postprocessing
        stage_start = time.perf_counter()
        result = postprocess(output)
        self.postprocessing_latency.record(self._elapsed_ms(stage_start))

        self.total_latency.record(self._elapsed_ms(request_start))
        return result

    def summary(self):
        """Return the current metrics of every stage, keyed by stage name."""
        return {
            "tokenization": self.tokenization_latency.get_metrics(),
            "inference": self.inference_latency.get_metrics(),
            "postprocessing": self.postprocessing_latency.get_metrics(),
            "total": self.total_latency.get_metrics(),
        }
Quality Monitoring
class QualityMonitor:
    """Sample predictions for review and summarise reviewed outcomes.

    ``metrics_history`` holds dict entries that ``compute_rolling_quality``
    reads through the keys ``"timestamp"``, ``"correct"`` and
    ``"error_type"``. Nothing in this class appends to it.
    """

    def __init__(self, sample_rate=0.01):
        self.sample_rate = sample_rate
        self.metrics_history = []

    def sample_for_review(self, prediction, input_text):
        """Return a review record for a random share of predictions.

        A record is produced with probability ``sample_rate``; otherwise the
        method returns ``None``.
        """
        selected = np.random.random() < self.sample_rate
        if not selected:
            return None
        review_item = {
            "input": input_text,
            "prediction": prediction,
            "timestamp": time.time(),
            "needs_review": True,
        }
        return review_item

    def compute_rolling_quality(self, window_hours=24):
        """Summarise the history entries newer than ``window_hours`` hours.

        Returns:
            ``None`` when no entry falls inside the window, otherwise a dict
            with the accuracy, the number of entries and a ``Counter`` of the
            error types of the incorrect entries.
        """
        window_start = time.time() - (window_hours * 3600)
        in_window = [
            entry
            for entry in self.metrics_history
            if entry["timestamp"] > window_start
        ]
        if len(in_window) == 0:
            return None

        correct_total = len([entry for entry in in_window if entry["correct"]])
        error_tally = Counter()
        for entry in in_window:
            if not entry["correct"]:
                error_tally[entry["error_type"]] += 1

        return {
            "accuracy": correct_total / len(in_window),
            "sample_count": len(in_window),
            "error_types": error_tally,
        }

    def detect_quality_drop(self, baseline_quality, current_quality, threshold=0.02):
        """Report whether quality fell by more than ``threshold``.

        Returns ``False`` when either quality value is ``None``.
        """
        value_missing = baseline_quality is None or current_quality is None
        if value_missing:
            return False
        return (baseline_quality - current_quality) > threshold
# Complete monitoring setup
class NLPModelMonitor:
    """Tie latency tracking, quality sampling and alerting together.

    Attributes:
        model: The monitored model (stored, not called by this class).
        latency_tracker: ``NLPLatencyTracker`` holding the latency windows.
        drift_detector: Optional drift detector; ``None`` until assigned.
        quality_monitor: ``QualityMonitor`` used to sample predictions.
        alert_callback: Optional callable ``(alert_type, details)``.
        review_queue: Sampled predictions awaiting review, oldest first.
    """

    def __init__(self, model, alert_callback=None):
        self.model = model
        self.latency_tracker = NLPLatencyTracker()
        self.drift_detector = None
        self.quality_monitor = QualityMonitor()
        self.alert_callback = alert_callback
        self.review_queue = []

    def log_prediction(self, input_text, prediction, latency_ms):
        """Log a prediction for monitoring.

        Records ``latency_ms`` in the end-to-end latency window, raises a
        ``latency_spike`` alert when it exceeds twice the window's P99, and
        queues the prediction for review when the quality monitor samples it.
        """
        # NLPLatencyTracker exposes no record() of its own; the end-to-end
        # window is the one that matches a caller-supplied total latency.
        total_latency = self.latency_tracker.total_latency
        total_latency.record(latency_ms)

        # Check for anomalies
        metrics = total_latency.get_metrics()
        if metrics is not None and latency_ms > metrics.p99 * 2:
            self._alert("latency_spike", {"latency": latency_ms, "input": input_text[:100]})

        # Sample for quality review
        sample = self.quality_monitor.sample_for_review(prediction, input_text)
        if sample:
            self._queue_for_review(sample)

    def _queue_for_review(self, sample):
        """Store a sampled prediction until a reviewer consumes it."""
        # The queue is unbounded; consumers are expected to drain it.
        self.review_queue.append(sample)

    def _alert(self, alert_type, details):
        """Forward an alert to the callback (if any) and print it."""
        if self.alert_callback:
            self.alert_callback(alert_type, details)
        print(f"ALERT [{alert_type}]: {details}")

    def get_dashboard_data(self):
        """Return data for monitoring dashboard.

        ``drift`` is ``None`` when no drift detector is attached or when the
        attached detector does not provide ``get_current_drift``.
        """
        # Not every detector implements get_current_drift (DriftDetector in
        # this file does not), so look it up defensively instead of crashing
        # the dashboard.
        drift_getter = getattr(self.drift_detector, "get_current_drift", None)
        return {
            "latency": self.latency_tracker.summary(),
            "quality": self.quality_monitor.compute_rolling_quality(),
            "drift": drift_getter() if callable(drift_getter) else None,
        }
Alert Configuration
| Alert Type | Metric | Warning | Critical | Action |
|---|---|---|---|---|
| Latency | P99 latency | > 300ms | > 500ms | Scale up |
| Throughput | Requests/sec | < 80% expected | < 50% expected | Investigate |
| Quality | Accuracy | Drop > 1% | Drop > 3% | Retrain |
| Drift | PSI score | > 0.1 | > 0.25 | Data review |
| Errors | Error rate | > 0.5% | > 2% | Rollback |
| Resources | GPU memory | > 85% | > 95% | Scale up |
Key Takeaways
- Latency monitoring should track P50, P95, and P99 percentiles
- Data drift detection using PSI or KL divergence catches distribution shifts early
- Quality sampling with human review provides ground truth for model performance
- Automated alerting enables rapid response to degradation
- Dashboard visualization helps teams understand system health at a glance