Interview Question (Hard) — Asked at: Google, Netflix, Uber, Amazon, Spotify
"Design an automated model retraining system that balances model freshness with computational cost. How do you determine when to retrain, what data to use, and how to validate the new model?"
Retraining Strategy Overview
Model retraining is critical for maintaining performance as data distributions change over time. The choice of retraining strategy depends on data velocity, computational resources, and latency requirements.
Retraining Strategy Comparison
| Strategy | Trigger | Latency | Cost | Best For |
|---|---|---|---|---|
| Scheduled | Time-based | Hours | Medium | Stable distributions |
| Triggered | Performance drop | Hours | Medium | Drift detection |
| Online | Per-example | Real-time | High | Streaming data |
| Active Learning | Uncertainty | Variable | Low | Label scarcity |
| Incremental | Batch arrival | Minutes | Low | High-frequency data |
Retraining Architecture Diagram
┌────────────────────────────────────────────────────────────────┐
│                    Model Retraining System                     │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │  Drift   │───▶│ Retrain  │───▶│ Training │───▶│ Validate │  │
│  │ Monitor  │    │ Trigger  │    │ Pipeline │    │ & Gate   │  │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘  │
│       │               │               │               │        │
│       ▼               ▼               ▼               ▼        │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │   Data   │    │ Schedule │    │ Resource │    │ Deploy/  │  │
│  │ Version  │    │ Manager  │    │ Manager  │    │ Rollback │  │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘  │
└────────────────────────────────────────────────────────────────┘
Scheduled Retraining
Airflow Scheduled Pipeline
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.sensors.external_task import ExternalTaskSensor
from datetime import datetime, timedelta
import json
# Task-level defaults; handed to the DAG constructor via ``default_args=``.
default_args = dict(
    owner='ml-team',
    depends_on_past=False,
    email_on_failure=True,
    retries=2,
    retry_delay=timedelta(minutes=10),
)
with DAG(
    'scheduled_model_retraining',
    default_args=default_args,
    description='Scheduled model retraining pipeline',
    schedule_interval='0 2 * * 0',  # Weekly on Sunday at 2 AM
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['retraining', 'scheduled'],
) as dag:
    # Imported here so the block is self-contained; it lives in the same
    # Airflow module as the PythonOperator used below.
    from airflow.operators.python import ShortCircuitOperator

    def prepare_training_data(**context):
        """Load the trailing 90-day training window and stage it locally.

        Returns:
            Path of the parquet file written under /tmp. The return value is
            what the training task later pulls from XCom.
        """
        import pandas as pd

        # Get date range for training
        execution_date = context['execution_date']
        training_window = 90  # days
        start_date = execution_date - timedelta(days=training_window)
        end_date = execution_date

        # Load data
        df = pd.read_parquet(
            f"s3://ml-data/training/{start_date:%Y%m%d}_{end_date:%Y%m%d}/"
        )

        # Save prepared data
        # NOTE(review): the training task reads this path back, which presumes
        # both tasks see the same /tmp -- confirm for the executor in use.
        output_path = f"/tmp/training_data_{execution_date:%Y%m%d}.parquet"
        df.to_parquet(output_path)
        return output_path

    def train_model(**context):
        """Train an XGBoost binary classifier on the staged data.

        Returns:
            Dict with ``model_path`` (saved model file) and ``auc``
            (validation ROC AUC). Both keys are consumed downstream: ``auc``
            by the evaluation step and ``model_path`` by the promote command.
        """
        import mlflow
        import pandas as pd
        import xgboost as xgb
        from sklearn.metrics import roc_auc_score
        from sklearn.model_selection import train_test_split

        ti = context['ti']
        data_path = ti.xcom_pull(task_ids='prepare_data')
        df = pd.read_parquet(data_path)

        mlflow.set_experiment("scheduled_retraining")
        with mlflow.start_run(run_name=f"retrain_{context['ds']}"):
            # Prepare features
            X = df.drop(columns=['label'])
            y = df['label']
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dval = xgb.DMatrix(X_val, label=y_val)

            params = {
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'max_depth': 6,
                'learning_rate': 0.1,
            }
            model = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,
                evals=[(dval, 'val')],
                early_stopping_rounds=50,
            )

            # Log metrics
            val_pred = model.predict(dval)
            # Plain float so the value serializes cleanly through XCom.
            auc = float(roc_auc_score(y_val, val_pred))
            mlflow.log_metric("auc_roc", auc)
            mlflow.log_param("training_samples", len(X_train))
            mlflow.log_param("validation_samples", len(X_val))

            # Save model
            model_path = f"/tmp/model_{context['ds']}.json"
            model.save_model(model_path)

        return {'model_path': model_path, 'auc': auc}

    def evaluate_and_compare(**context):
        """Compare the new model with the current production model.

        Returns:
            Dict with ``action`` ("promote" or "keep"), a human-readable
            ``reason``, and both AUC values.
        """
        import json

        ti = context['ti']
        training_results = ti.xcom_pull(task_ids='train_model')

        # Load current production model metrics; the context manager closes
        # the file even if parsing fails.
        with open('/models/production/metrics.json') as metrics_file:
            current_metrics = json.load(metrics_file)

        new_auc = training_results['auc']
        current_auc = current_metrics['auc_roc']

        # Decision logic
        improvement_threshold = 0.01  # 1% improvement required
        if new_auc > current_auc + improvement_threshold:
            action = "promote"
            reason = f"New model improved AUC by {new_auc - current_auc:.4f}"
        elif new_auc < current_auc - 0.02:
            action = "keep"
            reason = f"New model degraded AUC by {current_auc - new_auc:.4f}"
        else:
            action = "keep"
            reason = "Improvement below threshold"

        return {
            'action': action,
            'reason': reason,
            'new_auc': new_auc,
            'current_auc': current_auc,
        }

    def should_promote(**context):
        """Return True only when the evaluation step chose "promote"."""
        decision = context['ti'].xcom_pull(task_ids='evaluate_model')
        return decision['action'] == 'promote'

    # Task definitions
    prepare_data = PythonOperator(
        task_id='prepare_data',
        python_callable=prepare_training_data,
    )

    # Runs train_model() so the XCom under 'train_model' is the
    # {'model_path', 'auc'} dict that evaluate_and_compare and the promote
    # command template both read. A container task would not push that dict.
    train = PythonOperator(
        task_id='train_model',
        python_callable=train_model,
    )

    evaluate = PythonOperator(
        task_id='evaluate_model',
        python_callable=evaluate_and_compare,
    )

    # Gate: a falsy result skips everything downstream, so a "keep" decision
    # never reaches the promote step.
    promotion_gate = ShortCircuitOperator(
        task_id='check_promotion',
        python_callable=should_promote,
    )

    promote = DockerOperator(
        task_id='promote_model',
        image='registry.example.com/ml/deployment:latest',
        command='python promote.py --model-path {{ ti.xcom_pull(task_ids="train_model")["model_path"] }}',
    )

    prepare_data >> train >> evaluate >> promotion_gate >> promote
Triggered Retraining
Event-Driven Retraining System
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional
import json
import logging
from datetime import datetime, timedelta
from kafka import KafkaConsumer, KafkaProducer
import redis
logger = logging.getLogger(__name__)
class RetrainingTrigger(Enum):
    """Kinds of event that can request a model retraining.

    The string values are what appears in the ``trigger`` field of the
    serialized event and job payloads.
    """

    DATA_DRIFT = "data_drift"
    PERFORMANCE_DECAY = "performance_decay"
    SCHEDULE = "schedule"
    MANUAL = "manual"
    DATA_VOLUME = "data_volume"
@dataclass
class RetrainingEvent:
    """A single request to consider retraining."""

    # What raised the request.
    trigger: RetrainingTrigger
    # When the event was produced.
    timestamp: datetime
    # Trigger-specific details (schema not fixed here -- presumably keys such
    # as 'drift_score' or 'current_accuracy'; verify against producers).
    metadata: Dict
    priority: int  # 1 (low) to 5 (high)
class TriggeredRetrainingManager:
    """Turn retraining trigger events into queued retraining jobs.

    Events are consumed from the ``retraining-triggers`` Kafka topic, filtered
    by cooldown, priority and trigger-specific thresholds, and accepted events
    are published to ``retraining-jobs``. Redis stores the cooldown marker
    (``last_retrain_time``) and an audit list (``retraining_log``).
    """

    def __init__(self, config: dict):
        """Create the Redis client and Kafka consumer/producer.

        Args:
            config: Must contain ``redis_host``, ``redis_port`` and
                ``kafka_servers``. Optional keys: ``cooldown_hours``
                (default 24) and ``send_timeout_seconds`` (default 10).
        """
        self.config = config
        self.redis = redis.Redis(
            host=config['redis_host'],
            port=config['redis_port']
        )
        self.consumer = KafkaConsumer(
            'retraining-triggers',
            bootstrap_servers=config['kafka_servers'],
            value_deserializer=lambda m: json.loads(m.decode('utf-8'))
        )
        self.producer = KafkaProducer(
            bootstrap_servers=config['kafka_servers'],
            value_serializer=lambda v: json.dumps(v, default=str).encode('utf-8')
        )
        # Retraining cooldown
        self.cooldown_hours = config.get('cooldown_hours', 24)
        # How long to wait for the broker to acknowledge a queued job.
        self.send_timeout = config.get('send_timeout_seconds', 10)

    def should_retrain(self, event: RetrainingEvent) -> tuple:
        """Decide whether an event should start a retraining job.

        Checks, in order: the cooldown window, the event priority (must be at
        least 3), then the threshold for the specific trigger type.

        Returns:
            ``(decision, reason)`` where ``decision`` is a bool and ``reason``
            is a short explanation string.
        """
        # Check cooldown
        last_retrain = self.redis.get("last_retrain_time")
        if last_retrain:
            last_time = datetime.fromisoformat(last_retrain.decode())
            if datetime.now() - last_time < timedelta(hours=self.cooldown_hours):
                return False, "Cooldown period active"

        # Check priority
        if event.priority < 3:
            return False, "Priority too low"

        # Check trigger-specific conditions
        if event.trigger == RetrainingTrigger.DATA_DRIFT:
            drift_score = event.metadata.get('drift_score', 0)
            if drift_score < 0.3:
                return False, "Drift score below threshold"
        elif event.trigger == RetrainingTrigger.PERFORMANCE_DECAY:
            # Missing values default to 1, i.e. "no decay observed".
            current_accuracy = event.metadata.get('current_accuracy', 1)
            baseline_accuracy = event.metadata.get('baseline_accuracy', 1)
            decay = baseline_accuracy - current_accuracy
            if decay < 0.05:  # 5% decay threshold
                return False, "Performance decay below threshold"

        return True, "Retraining recommended"

    def process_event(self, event: RetrainingEvent):
        """Evaluate one trigger event and queue a retraining job if warranted.

        Raises:
            Whatever the Kafka producer raises when the job message cannot be
            delivered within ``send_timeout`` seconds. In that case the
            cooldown is NOT started, so a later trigger can retry.
        """
        should_retrain, reason = self.should_retrain(event)
        logger.info(
            "Processing trigger: %s, Decision: %s, Reason: %s",
            event.trigger.value,
            'RETRAIN' if should_retrain else 'SKIP',
            reason,
        )
        if not should_retrain:
            return

        # One timestamp for both the job and the cooldown marker so the two
        # records agree.
        queued_at = datetime.now().isoformat()

        # Create retraining job
        job = {
            'trigger': event.trigger.value,
            'timestamp': queued_at,
            'metadata': event.metadata,
            'priority': event.priority,
            'status': 'queued'
        }

        # Queue retraining job. send() is asynchronous; waiting on the
        # returned future surfaces delivery failures before the cooldown
        # starts.
        self.producer.send(
            'retraining-jobs',
            value=job
        ).get(timeout=self.send_timeout)

        # Update cooldown
        self.redis.setex(
            "last_retrain_time",
            timedelta(hours=self.cooldown_hours),
            queued_at
        )

        # Log event
        self._log_retraining_event(event, "triggered", reason)

    def listen_for_events(self):
        """Consume trigger events forever, processing each one.

        A failure on one message is logged with its traceback and the loop
        continues with the next message.
        """
        logger.info("Listening for retraining events...")
        for message in self.consumer:
            try:
                event_data = message.value
                event = RetrainingEvent(
                    trigger=RetrainingTrigger(event_data['trigger']),
                    timestamp=datetime.fromisoformat(event_data['timestamp']),
                    metadata=event_data['metadata'],
                    priority=event_data.get('priority', 3)
                )
                self.process_event(event)
            except Exception as e:
                # Top-level consumer boundary: keep the loop alive, but keep
                # the traceback for diagnosis.
                logger.exception("Error processing event: %s", e)

    def _log_retraining_event(self, event: RetrainingEvent,
                              action: str, reason: str):
        """Push an audit record onto the ``retraining_log`` Redis list."""
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'trigger': event.trigger.value,
            'action': action,
            'reason': reason,
            'metadata': event.metadata
        }
        self.redis.lpush(
            "retraining_log",
            json.dumps(log_entry, default=str)
        )
Performance Decay Detection
import numpy as np
from typing import Dict, List, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import pandas as pd
@dataclass
class PerformanceMetric:
    """One recorded observation of a performance metric."""

    timestamp: datetime
    value: float
    window: str  # '1h', '24h', '7d'


class PerformanceDecayDetector:
    """Track metric observations and flag decay relative to a baseline."""

    def __init__(self, baseline_metrics: Dict[str, float],
                 decay_thresholds: Dict[str, float]):
        """
        Args:
            baseline_metrics: Baseline performance metrics
            decay_thresholds: Threshold for each metric to trigger retraining
        """
        self.baseline_metrics = baseline_metrics
        self.decay_thresholds = decay_thresholds
        self.metric_history = {}

    def add_metric(self, metric_name: str, value: float,
                   window: str = '24h'):
        """Record a measurement for ``metric_name``, timestamped now."""
        self.metric_history.setdefault(metric_name, []).append(
            PerformanceMetric(
                timestamp=datetime.now(),
                value=value,
                window=window
            )
        )

    def detect_decay(self) -> Dict[str, Dict]:
        """Detect performance decay across all metrics.

        Only observations from the last 24 hours are considered, and a metric
        is skipped unless it has at least 10 of them.

        Returns:
            Mapping of metric name to a result dict. All values are native
            Python ``float``/``bool`` so the dict can be passed straight to
            ``json.dumps``.
        """
        from scipy import stats

        # Evaluate "recent" against a single reference time for the whole run.
        now = datetime.now()
        recent_window = timedelta(hours=24)

        decay_results = {}
        for metric_name, history in self.metric_history.items():
            # Get recent values
            recent_values = [
                m.value for m in history
                if now - m.timestamp < recent_window
            ]
            if len(recent_values) < 10:
                continue

            # Calculate statistics
            current_mean = float(np.mean(recent_values))
            current_std = float(np.std(recent_values))
            # With no configured baseline the current mean is used, which
            # yields zero decay for that metric.
            baseline = self.baseline_metrics.get(metric_name, current_mean)

            # Calculate decay
            decay = baseline - current_mean
            decay_percentage = (decay / baseline * 100) if baseline > 0 else 0

            # Check threshold (absolute difference, not percentage)
            threshold = self.decay_thresholds.get(metric_name, 0.05)
            decay_detected = bool(decay > threshold)

            # Statistical significance test. This is two-sided, so it reports
            # a significant difference from the baseline in either direction.
            _, p_value = stats.ttest_1samp(recent_values, baseline)
            p_value = float(p_value)

            decay_results[metric_name] = {
                'current_value': current_mean,
                'baseline_value': baseline,
                'decay': decay,
                'decay_percentage': decay_percentage,
                'std': current_std,
                'p_value': p_value,
                'statistically_significant': bool(p_value < 0.05),
                'decay_detected': decay_detected,
                'threshold': threshold
            }
        return decay_results

    def calculate_retraining_priority(self, decay_results: Dict) -> int:
        """Map the worst detected decay to a priority from 1 to 5.

        Metrics without ``decay_detected`` are ignored. Percentage bands:
        over 20 -> 5, over 10 -> 4, over 5 -> 3, otherwise 2.
        """
        priority = 1
        for result in decay_results.values():
            if not result['decay_detected']:
                continue
            pct = result['decay_percentage']
            if pct > 20:
                level = 5
            elif pct > 10:
                level = 4
            elif pct > 5:
                level = 3
            else:
                level = 2
            priority = max(priority, level)
        return priority
ℹ️
Triggered retraining balances computational cost with model freshness. Use performance decay detection with statistical significance tests to avoid false positives from normal metric variance.
Online Learning
Online Gradient Descent
import numpy as np
from typing import Optional
from collections import deque
class OnlineLinearModel:
    """Online logistic-regression model trained by per-example gradient descent."""

    def __init__(self, n_features: int, learning_rate: float = 0.01,
                 regularization: float = 0.001):
        """
        Args:
            n_features: Length of each feature vector.
            learning_rate: Step size applied to every update.
            regularization: L2 penalty coefficient applied to the weights
                (the bias is not regularized).
        """
        self.n_features = n_features
        self.learning_rate = learning_rate
        self.regularization = regularization
        # Initialize weights
        self.weights = np.zeros(n_features)
        self.bias = 0.0
        # Running statistics
        self.n_samples = 0
        self.loss_history = deque(maxlen=1000)  # only the latest 1000 losses

    def _sigmoid(self, z):
        """Sigmoid activation; input is clipped to avoid overflow in exp."""
        return 1 / (1 + np.exp(-np.clip(z, -500, 500)))

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predicted probabilities for the rows of ``X``."""
        z = X @ self.weights + self.bias
        return self._sigmoid(z)

    def update(self, X: np.ndarray, y: float):
        """Update the model with a single example.

        Returns:
            Log loss of the prediction made BEFORE this update was applied.
        """
        # Forward pass
        prediction = self.predict(X.reshape(1, -1))[0]

        # Gradient of log loss w.r.t. the logit
        error = prediction - y

        # Update weights
        gradient = error * X
        self.weights -= self.learning_rate * (
            gradient + self.regularization * self.weights
        )
        self.bias -= self.learning_rate * error

        # Track loss; the epsilon keeps log() finite at 0 and 1
        loss = -y * np.log(prediction + 1e-7) - (1 - y) * np.log(1 - prediction + 1e-7)
        self.loss_history.append(loss)
        self.n_samples += 1
        return loss

    def update_batch(self, X_batch: np.ndarray, y_batch: np.ndarray):
        """Update the model one example at a time over a mini-batch.

        Returns:
            Mean per-example loss, or ``0.0`` for an empty batch (which leaves
            the model untouched instead of dividing by zero).
        """
        batch_size = len(y_batch)
        if batch_size == 0:
            return 0.0
        total_loss = sum(self.update(X, y) for X, y in zip(X_batch, y_batch))
        return total_loss / batch_size

    def get_metrics(self) -> dict:
        """Return sample count, recent average loss, weight norm and bias."""
        return {
            'n_samples': self.n_samples,
            'avg_loss': np.mean(self.loss_history) if self.loss_history else 0,
            'weights_norm': np.linalg.norm(self.weights),
            'bias': self.bias
        }
class OnlineEnsemble:
    """Weighted combination of several online linear models."""

    def __init__(self, n_models: int, n_features: int,
                 learning_rate: float = 0.01):
        """Build ``n_models`` members and give each an equal starting weight."""
        members = []
        for _ in range(n_models):
            members.append(OnlineLinearModel(n_features, learning_rate))
        self.models = members
        self.weights = np.ones(n_models) / n_models

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the weight-averaged prediction of all members."""
        member_outputs = np.array([member.predict(X) for member in self.models])
        return np.average(member_outputs, axis=0, weights=self.weights)

    def update(self, X: np.ndarray, y: float):
        """Re-weight members by their loss on (X, y), then train each on it."""
        row = X.reshape(1, -1)

        # Score every member on the example before it learns from it.
        member_preds = np.array([member.predict(row)[0] for member in self.models])

        # Per-member log loss, computed element-wise over the prediction array.
        member_losses = (
            -y * np.log(member_preds + 1e-7)
            - (1 - y) * np.log(1 - member_preds + 1e-7)
        )

        # Exponential weighting: lower loss keeps more weight. Renormalize so
        # the weights keep summing to one.
        self.weights *= np.exp(-member_losses)
        self.weights /= self.weights.sum()

        # Finally let every member learn from the example.
        for member in self.models:
            member.update(X, y)

    def get_model_weights(self) -> dict:
        """Return the ensemble weights keyed as ``model_<index>``."""
        report = {}
        for position, weight in enumerate(self.weights):
            report[f'model_{position}'] = float(weight)
        return report
River - Online Machine Learning Library
from river import (
linear_model, preprocessing, metrics, compose, utils
)
from river import ensemble as river_ensemble
import numpy as np
class RiverOnlineClassifier:
    """Online classifier using River library.

    Metrics are updated prequentially: each example is scored with the model
    as it was before learning from that example.
    """

    def __init__(self):
        # Create pipeline with preprocessing
        # NOTE(review): newer River releases appear to expose this model as
        # river.forest.ARFClassifier -- confirm against the pinned version.
        self.model = compose.Pipeline(
            preprocessing.StandardScaler(),
            river_ensemble.AdaptiveRandomForestClassifier(
                n_models=10,
                seed=42
            )
        )
        # Metrics
        self.metric = metrics.Accuracy()
        self.auc_metric = metrics.ROCAUC()
        # Buffer for predictions
        self.prediction_buffer = []

    def learn_one(self, x: dict, y: int):
        """Score the example with the current model, then learn from it."""
        # Get prediction before learning
        y_pred = self.model.predict_one(x)
        if y_pred is not None:
            self.metric.update(y, y_pred)
            # .get() instead of [True]: the positive class has no entry in
            # the probability dict until the model has seen it at least once.
            positive_proba = self.model.predict_proba_one(x).get(True, 0.0)
            self.auc_metric.update(y, positive_proba)
        # Learn from example
        self.model.learn_one(x, y)

    def predict_one(self, x: dict) -> int:
        """Predict the label of a single example."""
        return self.model.predict_one(x)

    def predict_proba_one(self, x: dict) -> dict:
        """Predict class probabilities for a single example."""
        return self.model.predict_proba_one(x)

    def get_metrics(self) -> dict:
        """Return the running accuracy and ROC AUC."""
        return {
            'accuracy': self.metric.get(),
            'auc': self.auc_metric.get()
        }
class OnlineFeatureUpdater:
    """Update features in real-time for online learning."""

    def __init__(self, feature_store_url: str):
        self.feature_store_url = feature_store_url
        self.feature_cache = {}
        # Strong references to in-flight pushes so they are not
        # garbage-collected before they finish.
        self._pending_updates = set()

    def update_features(self, entity_id: str, features: dict):
        """Cache the features locally and push them to the feature store.

        Inside a running event loop the push is scheduled as a background
        task. Without one (plain synchronous caller) the push runs to
        completion before this method returns.
        """
        import asyncio

        # Update local cache
        self.feature_cache[entity_id] = {
            'features': features,
            'updated_at': datetime.now()
        }

        # Update feature store (async)
        push = self._update_feature_store(entity_id, features)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: create_task() would fail, so drive the
            # coroutine synchronously instead.
            asyncio.run(push)
            return

        task = loop.create_task(push)
        self._pending_updates.add(task)
        task.add_done_callback(self._finish_update)

    def _finish_update(self, task):
        """Release a finished push task and report any exception it raised."""
        import logging

        self._pending_updates.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logging.getLogger(__name__).error(
                "Feature store update failed", exc_info=error
            )

    async def _update_feature_store(self, entity_id: str, features: dict):
        """PUT the features to ``<feature_store_url>/features/<entity_id>``."""
        import logging

        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.put(
                f"{self.feature_store_url}/features/{entity_id}",
                json=features
            ) as response:
                if response.status != 200:
                    logging.getLogger(__name__).warning(
                        "Failed to update features for %s (HTTP %s)",
                        entity_id, response.status
                    )
⚠️
Online learning is sensitive to data quality. Implement input validation and outlier detection before updating the model. Consider using bounded buffers and gradient clipping to prevent instability.
Active Learning
Uncertainty Sampling
import numpy as np
from typing import List, Tuple
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV
class ActiveLearner:
    """Active learning with uncertainty sampling.

    Every ``query_*`` method returns indices into ``self.X_unlabeled``.
    """

    def __init__(self, base_model, n_initial: int = 100,
                 query_budget: int = 1000):
        """
        Args:
            base_model: Unfitted classifier that gets wrapped in
                ``CalibratedClassifierCV`` on every fit.
            n_initial: Size of the random seed set drawn by ``initialize``.
            query_budget: Maximum number of labels to request in total.
        """
        self.base_model = base_model
        self.n_initial = n_initial
        # Labeled data
        self.X_labeled = None
        self.y_labeled = None
        # Unlabeled pool
        self.X_unlabeled = None
        # Query budget
        self.query_budget = query_budget
        self.queries_made = 0

    def initialize(self, X_pool: np.ndarray, y_pool: np.ndarray):
        """Seed the labeled set with a random sample and fit the first model."""
        # Randomly select initial samples
        initial_idx = np.random.choice(
            len(X_pool),
            size=min(self.n_initial, len(X_pool)),
            replace=False
        )
        self.X_labeled = X_pool[initial_idx]
        self.y_labeled = y_pool[initial_idx]

        # Remove from pool
        mask = np.ones(len(X_pool), dtype=bool)
        mask[initial_idx] = False
        self.X_unlabeled = X_pool[mask]

        # Fit initial model
        self._fit_model()

    def _fit_model(self):
        """Fit a probability-calibrated model on the labeled data."""
        # Calibrate probabilities
        self.model = CalibratedClassifierCV(
            self.base_model,
            cv=3
        )
        self.model.fit(self.X_labeled, self.y_labeled)

    def _top_indices(self, scores: np.ndarray, n_samples: int) -> np.ndarray:
        """Return indices of the ``n_samples`` largest scores (ascending)."""
        if n_samples <= 0:
            # Guard: a [-0:] slice would select the entire pool.
            return np.array([], dtype=int)
        return np.argsort(scores)[-n_samples:]

    def _fitted_trees(self) -> list:
        """Collect fitted ensemble members for disagreement scoring.

        The calibration wrapper fits copies of ``base_model``, so the fitted
        members are looked up on the wrapper first.
        """
        trees = []
        wrapper = getattr(self, 'model', None)
        for calibrated in getattr(wrapper, 'calibrated_classifiers_', []):
            # NOTE(review): the attribute holding the fitted estimator is
            # 'estimator' or 'base_estimator' depending on the scikit-learn
            # version -- confirm for the pinned version.
            inner = getattr(calibrated, 'estimator', None)
            if inner is None:
                inner = getattr(calibrated, 'base_estimator', None)
            trees.extend(getattr(inner, 'estimators_', []))
        if not trees:
            # Covers a base model the caller fitted before handing it over.
            trees = list(getattr(self.base_model, 'estimators_', []))
        return trees

    def _remove_from_pool(self, indices: np.ndarray):
        """Drop the given row indices from the unlabeled pool."""
        keep = np.ones(len(self.X_unlabeled), dtype=bool)
        keep[np.asarray(indices, dtype=int)] = False
        self.X_unlabeled = self.X_unlabeled[keep]

    def query_uncertainty(self, n_samples: int = 10) -> np.ndarray:
        """Query samples with highest predictive entropy."""
        # Get predictions
        probabilities = self.model.predict_proba(self.X_unlabeled)
        # Calculate uncertainty (entropy); epsilon keeps log() finite
        entropy = -np.sum(
            probabilities * np.log(probabilities + 1e-7),
            axis=1
        )
        # Select most uncertain samples
        return self._top_indices(entropy, n_samples)

    def query_margin(self, n_samples: int = 10) -> np.ndarray:
        """Query samples with smallest margin between top 2 classes."""
        probabilities = self.model.predict_proba(self.X_unlabeled)
        # Sort probabilities, largest first
        sorted_probs = np.sort(probabilities, axis=1)[:, ::-1]
        # Calculate margin
        margin = sorted_probs[:, 0] - sorted_probs[:, 1]
        # Select samples with smallest margin
        return np.argsort(margin)[:n_samples]

    def query_random_forest_uncertainty(self, n_samples: int = 10) -> np.ndarray:
        """Query based on disagreement between ensemble members.

        Falls back to entropy-based ``query_uncertainty`` when no fitted
        ensemble members are available.
        """
        trees = self._fitted_trees()
        if not trees:
            return self.query_uncertainty(n_samples)
        # Positive-class probability from each member
        tree_predictions = np.array([
            tree.predict_proba(self.X_unlabeled)[:, 1]
            for tree in trees
        ])
        # Calculate variance (disagreement)
        variance = np.var(tree_predictions, axis=0)
        # Select most disagreed samples
        return self._top_indices(variance, n_samples)

    def add_labels(self, X_new: np.ndarray, y_new: np.ndarray,
                   indices: Optional[np.ndarray] = None):
        """Add newly labeled data and refit.

        Args:
            X_new: Feature rows that were labeled.
            y_new: Their labels.
            indices: Positions of ``X_new`` in ``self.X_unlabeled`` (as
                returned by a ``query_*`` method). When given, those rows are
                removed from the pool so they are not queried again. When
                omitted the pool is left unchanged.
        """
        self.X_labeled = np.vstack([self.X_labeled, X_new])
        self.y_labeled = np.concatenate([self.y_labeled, y_new])

        # Remove from unlabeled pool
        if indices is not None:
            self._remove_from_pool(indices)

        # Refit model
        self._fit_model()
        self.queries_made += len(X_new)

    def should_query(self) -> bool:
        """Return True while budget remains and the pool is non-empty."""
        return (
            self.queries_made < self.query_budget and
            len(self.X_unlabeled) > 0
        )

    def get_statistics(self) -> dict:
        """Return set sizes, query counts and the 0/1 label distribution."""
        return {
            'n_labeled': len(self.X_labeled),
            'n_unlabeled': len(self.X_unlabeled),
            'queries_made': self.queries_made,
            'query_budget': self.query_budget,
            'class_distribution': {
                'positive': int(np.sum(self.y_labeled == 1)),
                'negative': int(np.sum(self.y_labeled == 0))
            }
        }
Retraining Pipeline Orchestration
Complete Retraining Orchestrator
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional
import json
import logging
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
class RetrainingStrategy(Enum):
    """How a retraining job selects its data and is gated."""

    SCHEDULED = "scheduled"
    TRIGGERED = "triggered"
    ONLINE = "online"
    ACTIVE_LEARNING = "active_learning"


@dataclass
class RetrainingJob:
    """A unit of retraining work tracked by the orchestrator."""

    job_id: str
    strategy: RetrainingStrategy
    model_name: str
    trigger_reason: str
    priority: int
    created_at: datetime
    # Lifecycle: pending -> running -> completed | failed | error
    status: str = "pending"
    metadata: Optional[Dict] = None

    def __post_init__(self):
        # Give every job its own dict instead of leaving None behind.
        if self.metadata is None:
            self.metadata = {}


class RetrainingOrchestrator:
    """Queue, execute and track retraining jobs.

    A job moves from ``job_queue`` to ``active_jobs`` while running and ends
    in ``completed_jobs`` whatever its outcome.
    """

    def __init__(self, config: dict):
        self.config = config
        self.job_queue = []
        self.active_jobs = {}
        self.completed_jobs = []

    def create_retraining_job(self, strategy: RetrainingStrategy,
                              model_name: str, trigger_reason: str,
                              priority: int = 3) -> RetrainingJob:
        """Create a job, enqueue it, and keep the queue sorted by priority.

        Returns:
            The new job. Its ``job_id`` is ``job_<YYYYmmddHHMMSS>_<8 hex>``;
            the random suffix keeps ids unique when several jobs are created
            within the same second.
        """
        import uuid

        now = datetime.now()
        job = RetrainingJob(
            job_id=f"job_{now:%Y%m%d%H%M%S}_{uuid.uuid4().hex[:8]}",
            strategy=strategy,
            model_name=model_name,
            trigger_reason=trigger_reason,
            priority=priority,
            created_at=now,
            metadata={}
        )
        self.job_queue.append(job)
        # Sort by priority (highest first; the sort is stable, so equal
        # priorities stay in creation order)
        self.job_queue.sort(key=lambda x: x.priority, reverse=True)
        return job

    def execute_job(self, job: RetrainingJob) -> dict:
        """Run prepare -> train -> evaluate -> deploy for one job.

        Exceptions from any step are caught, logged with traceback, and
        reported in the result instead of propagating.

        Returns:
            Dict with ``job_id`` and ``status``, plus either the training and
            evaluation results or an ``error`` message.
        """
        # The job is no longer waiting once it starts running.
        self.job_queue = [queued for queued in self.job_queue if queued is not job]
        job.status = "running"
        self.active_jobs[job.job_id] = job
        try:
            # Step 1: Prepare data
            data = self._prepare_training_data(job)
            # Step 2: Train model
            training_results = self._train_model(job, data)
            # Step 3: Evaluate model
            evaluation_results = self._evaluate_model(job, training_results)
            # Step 4: Deploy if successful
            if evaluation_results['passed']:
                self._deploy_model(job, training_results)
                job.status = "completed"
            else:
                job.status = "failed"
            result = {
                'job_id': job.job_id,
                'status': job.status,
                'training_results': training_results,
                'evaluation_results': evaluation_results
            }
        except Exception as e:
            logging.getLogger(__name__).exception(
                "Retraining job %s failed", job.job_id
            )
            job.status = "error"
            result = {
                'job_id': job.job_id,
                'status': 'error',
                'error': str(e)
            }
        finally:
            # Move to completed
            self.completed_jobs.append(job)
            self.active_jobs.pop(job.job_id, None)
        return result

    def _prepare_training_data(self, job: RetrainingJob) -> pd.DataFrame:
        """Load the training data appropriate to the job's strategy."""
        if job.strategy == RetrainingStrategy.SCHEDULED:
            # Use last N days of data
            end_date = datetime.now()
            start_date = end_date - timedelta(days=90)
            data = pd.read_parquet(
                f"s3://ml-data/training/{start_date:%Y%m%d}_{end_date:%Y%m%d}/"
            )
        elif job.strategy == RetrainingStrategy.TRIGGERED:
            # Use data from last retraining to now
            last_retrain = self._get_last_retrain_time(job.model_name)
            data = pd.read_parquet(
                f"s3://ml-data/training/{last_retrain:%Y%m%d}_{datetime.now():%Y%m%d}/"
            )
        elif job.strategy == RetrainingStrategy.ONLINE:
            # Use recent streaming data
            data = self._get_recent_streaming_data(hours=24)
        else:
            # Default: use last 30 days
            data = self._get_training_data(days=30)
        return data

    def _train_model(self, job: RetrainingJob, data: pd.DataFrame) -> dict:
        """Train an XGBoost binary classifier and log the run to MLflow.

        Returns:
            Dict with the trained ``model``, validation ``auc`` and the
            train/validation sample counts.
        """
        import mlflow
        import xgboost as xgb
        from sklearn.metrics import roc_auc_score
        from sklearn.model_selection import train_test_split

        mlflow.set_experiment(f"retraining_{job.model_name}")
        with mlflow.start_run(run_name=job.job_id):
            X = data.drop(columns=['label'])
            y = data['label']
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dval = xgb.DMatrix(X_val, label=y_val)
            params = {
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'max_depth': 6,
                'learning_rate': 0.1,
            }
            model = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,
                evals=[(dval, 'val')],
                early_stopping_rounds=50
            )
            # Log metrics
            val_pred = model.predict(dval)
            auc = float(roc_auc_score(y_val, val_pred))
            mlflow.log_metric("auc_roc", auc)
            mlflow.log_param("strategy", job.strategy.value)
            mlflow.log_param("trigger_reason", job.trigger_reason)
        return {
            'model': model,
            'auc': auc,
            'training_samples': len(X_train),
            'validation_samples': len(X_val)
        }

    def _evaluate_model(self, job: RetrainingJob,
                        training_results: dict) -> dict:
        """Evaluate model against quality gates.

        Every job must reach the absolute AUC floor. TRIGGERED jobs must also
        beat the production AUC by ``min_improvement``.

        Returns:
            Dict with ``passed`` plus either a ``reason`` (on failure) or the
            AUC comparison (on success).
        """
        thresholds = {
            'auc_roc': 0.90,
            'min_improvement': 0.01
        }
        # Get current production model performance
        current_metrics = self._get_production_metrics(job.model_name)
        new_auc = training_results['auc']
        current_auc = current_metrics.get('auc_roc', 0)
        improvement = new_auc - current_auc

        # Check absolute threshold
        if new_auc < thresholds['auc_roc']:
            return {
                'passed': False,
                'reason': f"AUC {new_auc:.4f} below threshold {thresholds['auc_roc']}"
            }

        # Check improvement for triggered retraining
        if (job.strategy == RetrainingStrategy.TRIGGERED
                and improvement < thresholds['min_improvement']):
            return {
                'passed': False,
                'reason': f"Improvement {improvement:.4f} below threshold"
            }

        return {
            'passed': True,
            'new_auc': new_auc,
            'current_auc': current_auc,
            'improvement': improvement
        }

    def _deploy_model(self, job: RetrainingJob,
                      training_results: dict):
        """Save the model, record it in the registry and trigger deployment."""
        # Save model
        # NOTE(review): confirm the model's save_model() can write to an
        # s3:// URI in this environment; otherwise save locally and upload.
        model_path = f"s3://ml-models/{job.model_name}/{job.job_id}/model.json"
        training_results['model'].save_model(model_path)
        # Update model registry
        self._update_model_registry(
            job.model_name,
            job.job_id,
            training_results
        )
        # Trigger deployment
        self._trigger_deployment(job.model_name, job.job_id)

    def _get_last_retrain_time(self, model_name: str) -> datetime:
        """Get last retraining time for a model (placeholder: 7 days ago)."""
        # Query model registry
        return datetime.now() - timedelta(days=7)

    def _get_production_metrics(self, model_name: str) -> dict:
        """Get current production model metrics (placeholder: fixed value)."""
        # Query monitoring system
        return {'auc_roc': 0.92}

    def _get_recent_streaming_data(self, hours: int) -> pd.DataFrame:
        """Return streaming data from the last ``hours`` hours (not implemented)."""
        raise NotImplementedError(
            "_get_recent_streaming_data must be implemented for ONLINE jobs"
        )

    def _get_training_data(self, days: int) -> pd.DataFrame:
        """Return training data from the last ``days`` days (not implemented)."""
        raise NotImplementedError(
            "_get_training_data must be implemented for this strategy"
        )

    def _update_model_registry(self, model_name: str,
                               version: str, metrics: dict):
        """Update model registry with new version (placeholder)."""
        pass

    def _trigger_deployment(self, model_name: str, version: str):
        """Trigger deployment pipeline (placeholder)."""
        pass

    def get_statistics(self) -> dict:
        """Return queue sizes and the count of finished jobs per strategy."""
        return {
            'queued_jobs': len(self.job_queue),
            'active_jobs': len(self.active_jobs),
            'completed_jobs': len(self.completed_jobs),
            'jobs_by_strategy': {
                strategy.value: sum(
                    1 for j in self.completed_jobs
                    if j.strategy == strategy
                )
                for strategy in RetrainingStrategy
            }
        }
ℹ️
A robust retraining system combines multiple strategies: scheduled retraining for baseline freshness, triggered retraining for drift response, and online learning for real-time adaptation. Monitor computational costs and model performance to optimize the retraining schedule.
Summary
Model retraining strategies include:
- Scheduled: Time-based retraining for predictable workloads
- Triggered: Event-driven retraining based on drift or performance decay
- Online: Continuous learning from streaming data
- Active Learning: Strategic sample selection for label efficiency
Choose the strategy that balances model freshness, computational cost, and labeling requirements for your use case.