Designing ML Systems at Scale
Building a machine learning model is the easy part. Deploying it reliably, monitoring it in production, and iterating on it safely — that's where real engineering happens. This lesson walks you through designing ML systems that scale.
The ML System Lifecycle
A production ML system is far more than a model artifact. It encompasses data ingestion, feature engineering, training, serving, monitoring, and feedback loops.
Feature Stores
A feature store centralizes feature computation and serving, ensuring consistency between training and inference.
# FeatureView is imported explicitly: the feature-view definition below uses
# it, and without the import the script fails with a NameError.
from feast import FeatureStore, FeatureView, Entity, Feature, ValueType
from feast import FileSource, BigQuerySource
from datetime import timedelta

# Define entities: the key that feature rows are looked up by.
driver = Entity(
    name="driver_id",
    value_type=ValueType.INT64,
    description="Driver identifier",
)

# Define features
avg_daily_rides = Feature(
    name="avg_daily_rides",
    value_type=ValueType.FLOAT,
    description="Average rides per day over last 30 days",
)

# Create a feature view that ties the entity and feature to a BigQuery table.
# NOTE(review): ttl presumably bounds how stale a served value may be
# (one day here) -- confirm against the Feast version in use.
driver_stats = FeatureView(
    name="driver_statistics",
    entities=["driver_id"],
    ttl=timedelta(days=1),
    features=[avg_daily_rides],
    online=True,
    source=BigQuerySource(
        table="project.dataset.driver_stats",
        event_timestamp_column="event_timestamp",
    ),
)

# Store and retrieve: register the definitions with the repo in the
# current directory.
store = FeatureStore(repo_path=".")
store.apply([driver, driver_stats])

# Online serving: fetch one feature for one driver as a plain dict.
# Feature references use the "<feature view name>:<feature name>" form.
features = store.get_online_features(
    features=["driver_statistics:avg_daily_rides"],
    entity_rows=[{"driver_id": 1001}],
).to_dict()
Online vs Batch Serving
| Aspect | Batch | Online |
|---|---|---|
| Latency | Minutes to hours | Milliseconds |
| Cost | Lower | Higher |
| Use Case | Reports, retraining | Real-time predictions |
| Freshness | Periodic | Near real-time |
# Batch prediction pipeline
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("BatchPrediction").getOrCreate()


def batch_predict(model_path, input_path, output_path):
    """Score a Parquet dataset with a joblib model and write the predictions.

    Reads ``input_path`` as Parquet, scores every row partition by partition,
    and writes a Parquet dataset with columns ``id`` and ``prediction`` to
    ``output_path``.

    Args:
        model_path: Location of a joblib-serialised model. It is loaded with
            ``joblib.load`` inside each partition, so it must be readable
            from wherever the partitions run.
        input_path: Parquet input; rows must expose ``id`` and ``features``.
        output_path: Destination for the Parquet output.
    """
    # Imported here because the surrounding file only imports SparkSession.
    from pyspark.sql import Row
    from pyspark.sql.types import DoubleType, StructField, StructType

    df = spark.read.parquet(input_path)

    # Only the *path* is broadcast; each partition loads its own copy of the
    # model from that path.
    broadcast_model = spark.sparkContext.broadcast(model_path)

    def predict_partition(partition):
        import joblib

        rows = list(partition)
        if not rows:
            # Nothing to score: skip the model load and avoid calling
            # predict() on an empty feature list.
            return
        model = joblib.load(broadcast_model.value)
        features = [row.features for row in rows]
        preds = model.predict(features)
        for row, pred in zip(rows, preds):
            yield Row(id=row.id, prediction=float(pred))

    # Explicit schema: reuse the input's ``id`` field so its type carries
    # over, and avoid schema inference, which cannot handle an empty result.
    output_schema = StructType([
        df.schema["id"],
        StructField("prediction", DoubleType(), nullable=False),
    ])

    # mapPartitions lives on the RDD, not the DataFrame, and the resulting
    # RDD must be turned back into a DataFrame before it can be written.
    prediction_rows = df.repartition(100).rdd.mapPartitions(predict_partition)
    predictions = spark.createDataFrame(prediction_rows, schema=output_schema)
    predictions.write.parquet(output_path)
# Online prediction service
from fastapi import FastAPI
import numpy as np
app = FastAPI()

# Cache for the production model so the registry is hit once, not once per
# request.
_production_model = None


def _get_production_model():
    """Return the production model, loading it from the registry on first use."""
    global _production_model
    if _production_model is None:
        _production_model = load_model_from_registry("production", 1)
    return _production_model


@app.post("/predict")
async def predict(features: dict):
    """Score a single feature vector.

    Args:
        features: Request body; ``features["values"]`` holds the feature
            values for one example.

    Returns:
        ``{"prediction": [...]}`` with the model output converted to a list.
    """
    model = _get_production_model()
    # reshape(1, -1) turns the flat value list into a single-row 2-D array.
    X = np.array(features["values"]).reshape(1, -1)
    # NOTE(review): model.predict is a synchronous call inside an async
    # handler -- consider offloading it if inference is slow.
    prediction = model.predict(X)
    return {"prediction": prediction.tolist()}
Monitoring and Observability
Production models degrade silently without monitoring. Track data quality, model performance, and system health.
from dataclasses import dataclass
from typing import Optional
import numpy as np
@dataclass
class ModelMetrics:
    """Snapshot of summary statistics for a served model."""

    # Mean of the model's predictions.
    prediction_mean: float
    # Standard deviation of the model's predictions.
    prediction_std: float
    # Null rate -- presumably the fraction of missing input values; confirm.
    null_rate: float
    # 99th-percentile latency -- presumably in milliseconds; confirm.
    latency_p99: float
class ModelMonitor:
    """Collect drift and latency alerts against a fixed reference sample.

    Every ``check_*`` method returns ``True`` and appends a human-readable
    message to ``self.alerts`` when its threshold is breached; otherwise it
    returns ``False`` and leaves ``alerts`` untouched. An empty input carries
    no evidence either way, so it is reported as ``False``.
    """

    def __init__(self, reference_distribution):
        """Store the reference sample and start with an empty alert list."""
        self.reference = reference_distribution
        self.alerts = []

    def check_data_drift(self, current_batch, p_threshold=0.01):
        """Flag drift with a two-sample Kolmogorov-Smirnov test.

        Args:
            current_batch: Sample to compare against the reference.
            p_threshold: Drift is reported when the KS p-value is below this.

        Returns:
            ``True`` if drift was detected (and an alert recorded).
        """
        from scipy.stats import ks_2samp

        if np.size(current_batch) == 0:
            # ks_2samp rejects empty samples; treat "no data" as "no drift".
            return False
        stat, p_value = ks_2samp(self.reference, current_batch)
        if p_value < p_threshold:
            self.alerts.append(f"Data drift detected: KS={stat:.4f}, p={p_value:.6f}")
            return True
        return False

    def check_prediction_distribution(self, predictions, z_threshold=3):
        """Flag a shift of the prediction mean away from the reference mean.

        The shift is measured in reference standard deviations.

        Args:
            predictions: Current predictions.
            z_threshold: Alert when the shift exceeds this many std devs.

        Returns:
            ``True`` if the shift exceeded the threshold (and an alert was
            recorded).
        """
        if np.size(predictions) == 0:
            # Avoids the NaN mean (and warning) of an empty array.
            return False
        current_mean = np.mean(predictions)
        ref_mean = np.mean(self.reference)
        # The epsilon keeps the ratio finite for a constant reference sample.
        drift = abs(current_mean - ref_mean) / (np.std(self.reference) + 1e-8)
        if drift > z_threshold:
            self.alerts.append(f"Prediction distribution shift: {drift:.2f} std devs")
            return True
        return False

    def check_latency(self, latencies_ms, threshold=200):
        """Flag a 99th-percentile latency above ``threshold`` milliseconds.

        Args:
            latencies_ms: Observed latencies in milliseconds.
            threshold: Maximum acceptable P99 latency in milliseconds.

        Returns:
            ``True`` if P99 exceeded the threshold (and an alert was recorded).
        """
        if np.size(latencies_ms) == 0:
            # A percentile of nothing is undefined; nothing to alert on.
            return False
        p99 = np.percentile(latencies_ms, 99)
        if p99 > threshold:
            self.alerts.append(f"Latency P99={p99:.1f}ms exceeds threshold")
            return True
        return False
A/B Testing Infrastructure
Proper experimentation requires statistical rigor and infrastructure to avoid common pitfalls.
import numpy as np
from scipy import stats
class ABTestAnalyzer:
    """Two-proportion z-test helper for conversion-rate A/B experiments.

    Args:
        alpha: Two-sided significance level used for sizing and for the
            ``significant`` flag.
        mde: Minimum detectable effect, expressed as a *relative* change of
            the baseline rate (0.02 means baseline * 1.02).
    """

    def __init__(self, alpha=0.05, mde=0.02):
        self.alpha = alpha
        self.mde = mde

    def required_sample_size(self, baseline_rate, power=0.8):
        """Return the per-variant sample size needed to detect ``mde``.

        Uses the standard two-proportion formula with a two-sided test at
        ``self.alpha``.

        Args:
            baseline_rate: Conversion rate of the control, strictly in (0, 1).
            power: Desired statistical power.

        Returns:
            The sample size as an ``int``, rounded up.

        Raises:
            ValueError: If the baseline or the implied treatment rate is not
                strictly between 0 and 1, or if the two rates are equal
                (``mde == 0``), which would need an infinite sample.
        """
        p1 = baseline_rate
        p2 = baseline_rate * (1 + self.mde)
        if not 0 < p1 < 1:
            raise ValueError("baseline_rate must be strictly between 0 and 1")
        if not 0 < p2 < 1 or p2 == p1:
            raise ValueError(
                "mde must yield a treatment rate in (0, 1) that differs "
                "from the baseline"
            )
        pooled = (p1 + p2) / 2
        z_alpha = stats.norm.ppf(1 - self.alpha / 2)
        z_beta = stats.norm.ppf(power)
        numerator = (z_alpha * np.sqrt(2 * pooled * (1 - pooled)) +
                     z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) ** 2
        denominator = (p2 - p1) ** 2
        return int(np.ceil(numerator / denominator))

    def analyze(self, control_conversions, control_total,
                treatment_conversions, treatment_total):
        """Compare treatment against control with a two-sided z-test.

        Args:
            control_conversions: Conversions observed in control.
            control_total: Units exposed to control (must be positive).
            treatment_conversions: Conversions observed in treatment.
            treatment_total: Units exposed to treatment (must be positive).

        Returns:
            A dict of built-in Python values with keys ``lift`` (absolute
            difference in rates), ``lift_pct`` (relative difference in
            percent; NaN when the control rate is zero), ``p_value``,
            ``significant`` (``p_value < alpha``) and ``ci_95`` (a
            ``(lower, upper)`` 95% interval for the absolute lift).

        Raises:
            ValueError: If either total is not positive.
        """
        if control_total <= 0 or treatment_total <= 0:
            raise ValueError("control_total and treatment_total must be positive")

        p_control = control_conversions / control_total
        p_treatment = treatment_conversions / treatment_total
        lift = p_treatment - p_control

        # Hypothesis test: under H0 both arms share one rate, so the standard
        # error is built from the pooled proportion.
        pooled = (control_conversions + treatment_conversions) / \
            (control_total + treatment_total)
        se_pooled = np.sqrt(
            pooled * (1 - pooled) * (1 / control_total + 1 / treatment_total)
        )
        if se_pooled > 0:
            z_stat = lift / se_pooled
            # sf() keeps precision in the far tail where 1 - cdf() rounds to 0.
            p_value = float(2 * stats.norm.sf(abs(z_stat)))
        else:
            # Pooled rate of exactly 0 or 1: the arms are identical, so there
            # is no evidence of a difference.
            p_value = 1.0

        # Confidence interval: it does not assume H0, so it uses the unpooled
        # (Wald) standard error of the difference.
        se_diff = np.sqrt(
            p_control * (1 - p_control) / control_total
            + p_treatment * (1 - p_treatment) / treatment_total
        )
        z_95 = stats.norm.ppf(0.975)
        ci_lower = lift - z_95 * se_diff
        ci_upper = lift + z_95 * se_diff

        # Relative lift is undefined when the control never converts.
        lift_pct = lift / p_control * 100 if p_control > 0 else float("nan")

        return {
            "lift": float(lift),
            "lift_pct": float(lift_pct),
            "p_value": p_value,
            "significant": bool(p_value < self.alpha),
            "ci_95": (float(ci_lower), float(ci_upper)),
        }
Online Learning Rate Schedules
When deploying ML systems, learning rate scheduling is critical for convergence. A common choice is exponential decay:
$$\eta_t = \eta_0 \, e^{-\lambda t}$$
where $\eta_0$ is the initial learning rate, $\lambda$ is the decay rate, and $t$ is the current step.
Key Takeaways
- Feature stores eliminate train-serve skew
- Choose batch vs online based on latency requirements and cost
- Monitor data distributions, predictions, and system metrics continuously
- A/B tests need proper sample sizing and sequential analysis to avoid peeking problems