Building Experimentation Platforms
A/B testing at scale requires infrastructure. Learn to build an experimentation platform that handles assignment, analysis, and governance.
Platform Architecture
Feature Flag System
import hashlib
import json
from datetime import datetime
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
@dataclass
class FeatureFlag:
    """Definition of a feature flag and its targeting rules.

    ``rules`` holds rule dicts, each with a ``"conditions"`` list and an
    optional ``"variants"`` override. ``variants`` maps a variant name to its
    traffic weight, expressed as a fraction of 1.
    """

    name: str
    enabled: bool
    rules: list
    default_variant: str = "off"
    variants: Dict[str, float] = field(default_factory=dict)
    kill_switch: bool = False
    # Filled with the local construction time when not supplied.
    created_at: Optional[datetime] = None

    def __post_init__(self):
        # Give each instance its own timestamp unless the caller provided one.
        if self.created_at is not None:
            return
        self.created_at = datetime.now()
class FeatureFlagService:
    """In-memory registry of feature flags with deterministic, hash-based assignment.

    Flags live in ``self.flags`` keyed by flag name. ``storage`` is stored on
    the instance but not read by any method here.
    """

    def __init__(self, storage):
        self.storage = storage
        self.flags: Dict[str, "FeatureFlag"] = {}

    def get_variant(self, flag_name: str, user_id: str, context: Optional[dict] = None) -> str:
        """Return the variant ``user_id`` should see for ``flag_name``.

        Returns ``"off"`` when the flag is unknown, disabled, or kill-switched.
        Otherwise the first rule whose conditions all hold picks the variant
        via a deterministic hash. If no rule matches, the flag's
        ``default_variant`` is returned.
        """
        flag = self.flags.get(flag_name)
        if not flag or not flag.enabled or flag.kill_switch:
            return "off"
        for rule in flag.rules:
            if self._matches_rule(rule, user_id, context):
                variants = rule.get("variants", flag.variants)
                return self._hash_to_variant(flag, user_id, variants)
        return flag.default_variant

    def _matches_rule(self, rule: dict, user_id: str, context: Optional[dict]) -> bool:
        """Return True when every condition in ``rule`` holds for this user.

        A rule without conditions matches everyone. Supported operators:
        ``equals`` (context value == value), ``in`` (context value is a member
        of value), and ``percentage`` (user's bucket in 0-99 is below value).
        Unknown operators never match, so a typo in a rule cannot silently
        expose a feature to every user.
        """
        ctx = context or {}
        for condition in rule.get("conditions", []):
            field_name = condition.get("field")
            operator = condition.get("operator")
            value = condition.get("value")
            user_value = ctx.get(field_name)
            if operator == "equals":
                if user_value != value:
                    return False
            elif operator == "in":
                # A missing value list means no user is in the set.
                if value is None or user_value not in value:
                    return False
            elif operator == "percentage":
                # NOTE: the bucket is salted only with the condition's field name,
                # not the flag name, so rules sharing a field name and percentage
                # select the same users across flags.
                bucket = int(hashlib.md5(f"{user_id}:{field_name}".encode()).hexdigest(), 16) % 100
                if bucket >= value:
                    return False
            else:
                # Fail closed on operators this service does not understand.
                return False
        return True

    def _hash_to_variant(self, flag: "FeatureFlag", user_id: str, variants: dict) -> str:
        """Deterministically map ``user_id`` to one of ``variants``.

        ``variants`` maps variant name -> weight (fraction of 1). The user's
        bucket (0-99) comes from an MD5 of ``"<flag name>:<user_id>"`` and is
        compared against cumulative weights in dict order. Any probability
        left over (weights summing to < 1) goes to the last variant. With no
        variants at all, the flag's ``default_variant`` is returned rather
        than raising IndexError.
        """
        if not variants:
            return flag.default_variant
        bucket = int(hashlib.md5(f"{flag.name}:{user_id}".encode()).hexdigest(), 16) % 100
        cumulative = 0
        for variant, weight in variants.items():
            cumulative += weight * 100
            if bucket < cumulative:
                return variant
        return list(variants.keys())[-1]

    def set_kill_switch(self, flag_name: str, enabled: bool):
        """Emergency kill switch: while set, ``get_variant`` returns "off". Unknown flags are ignored."""
        if flag_name in self.flags:
            self.flags[flag_name].kill_switch = enabled
# Usage
service = FeatureFlagService(storage=None)
# Target US/CA users, admit those whose percentage bucket is below 50, and
# split the admitted population evenly between control and treatment.
service.flags.update(
    new_checkout=FeatureFlag(
        name="new_checkout",
        enabled=True,
        rules=[
            {
                "conditions": [
                    {"field": "country", "operator": "in", "value": ["US", "CA"]},
                    {"field": "percentage", "operator": "percentage", "value": 50},
                ],
                "variants": {"control": 0.5, "treatment": 0.5},
            }
        ],
    )
)
variant = service.get_variant("new_checkout", "user_123", context={"country": "US"})
Gradual Rollout System
import time
from dataclasses import dataclass
from typing import List, Callable
from enum import Enum
class RolloutStage(Enum):
    """Named rollout phases.

    Not referenced by the rollout classes shown alongside it; stage dicts in
    ``RolloutConfig.stages`` carry their own string names.
    """

    CANARY = "canary"
    EARLY_ADOPTION = "early_adoption"
    RAMPING = "ramping"
    FULL = "full"
@dataclass
class RolloutConfig:
    """Settings that drive a staged rollout."""

    # Ordered stage dicts; each supplies "traffic_pct" and may supply
    # "rollback_thresholds".
    stages: List[dict]
    # When False, rollback is never recommended automatically.
    auto_rollback: bool = True
    # Callables taking a metrics dict and returning True when healthy;
    # None or empty means "no checks".
    health_checks: Optional[List[Callable]] = None
class GradualRollout:
    """Step one feature flag's treatment traffic through configured stages.

    Each ``config.stages`` entry supplies ``"traffic_pct"`` (fraction of
    traffic sent to ``"treatment"``) and may supply ``"rollback_thresholds"``
    (metric name -> upper limit). Stage changes overwrite the ``variants`` of
    the flag's first targeting rule.
    """

    def __init__(self, config: "RolloutConfig", flag_service: "FeatureFlagService",
                 flag_name: str = "feature_x"):
        self.config = config
        self.flag_service = flag_service
        # Flag being rolled out; the default keeps the previously hard-coded name.
        self.flag_name = flag_name
        self.current_stage = 0
        self.start_time = time.time()

    def _apply_stage(self) -> None:
        """Write the current stage's control/treatment split into the flag's first rule."""
        traffic = self.config.stages[self.current_stage]["traffic_pct"]
        self.flag_service.flags[self.flag_name].rules[0]["variants"] = {
            "control": 1 - traffic,
            "treatment": traffic,
        }

    def advance_stage(self) -> bool:
        """Move to the next stage and apply its split.

        Returns False (and changes nothing) when already at the last stage.
        Health is not checked here; call ``check_health`` first.
        """
        if self.current_stage >= len(self.config.stages) - 1:
            return False
        self.current_stage += 1
        self._apply_stage()
        return True

    def check_health(self, metrics: dict) -> bool:
        """Return True if every configured health check passes (or none are configured)."""
        return all(check(metrics) for check in (self.config.health_checks or []))

    def should_rollback(self, metrics: dict) -> bool:
        """Return True if any metric meets or exceeds its stage's rollback threshold.

        Thresholds are ceilings (e.g. error rate, p99 latency), consistent with
        health checks that require values strictly below their limits. Metrics
        absent from ``metrics`` do not trigger a rollback. Always False when
        ``auto_rollback`` is disabled.
        """
        if not self.config.auto_rollback:
            return False
        thresholds = self.config.stages[self.current_stage].get("rollback_thresholds", {})
        for metric, threshold in thresholds.items():
            observed = metrics.get(metric)
            # Breaching a ceiling means the stage is unhealthy.
            if observed is not None and observed >= threshold:
                return True
        return False

    def rollback(self) -> bool:
        """Step back one stage and re-apply its traffic split.

        At the first stage there is no earlier stage to return to, so the
        flag's kill switch is engaged instead. Returns True when a previous
        stage was applied, False when the kill switch was used.
        """
        if self.current_stage > 0:
            self.current_stage -= 1
            self._apply_stage()
            return True
        self.flag_service.set_kill_switch(self.flag_name, True)
        return False
# Health check functions


def latency_check(metrics, threshold_ms=200):
    """Pass when p99 latency (0 if absent) is strictly below ``threshold_ms``."""
    observed_p99 = metrics.get('p99_latency', 0)
    return observed_p99 < threshold_ms


def error_rate_check(metrics, threshold=0.01):
    """Pass when the error rate (0 if absent) is strictly below ``threshold``."""
    observed_errors = metrics.get('error_rate', 0)
    return observed_errors < threshold


def conversion_rate_check(metrics, min_lift=-0.05):
    """Pass when conversion lift (0 if absent) is strictly above ``min_lift``."""
    observed_lift = metrics.get('conversion_lift', 0)
    return observed_lift > min_lift
# Configure rollout: stages widen treatment traffic step by step, each with
# its own limits on error rate and p99 latency.
config = RolloutConfig(
    stages=[
        {"name": "canary", "traffic_pct": 0.01, "duration_hours": 2,
         "rollback_thresholds": {"error_rate": 0.01, "p99_latency": 300}},
        {"name": "early", "traffic_pct": 0.1, "duration_hours": 24,
         "rollback_thresholds": {"error_rate": 0.005, "p99_latency": 200}},
        {"name": "ramp_50", "traffic_pct": 0.5, "duration_hours": 48,
         "rollback_thresholds": {"error_rate": 0.005, "p99_latency": 200}},
        {"name": "full", "traffic_pct": 1.0, "duration_hours": 0},
    ],
    auto_rollback=True,
    # Each check already takes a single metrics dict and has default
    # thresholds, so pass the functions directly rather than wrapping them.
    health_checks=[latency_check, error_rate_check, conversion_rate_check],
)
Statistical Engine
import numpy as np
from scipy import stats
from dataclasses import dataclass
@dataclass
class ExperimentConfig:
    """Settings shared by sample-size planning and result analysis."""

    alpha: float = 0.05  # significance level (split across both tails in planning)
    power: float = 0.8  # target probability of detecting an effect of size ``mde``
    mde: float = 0.02  # minimum detectable effect, as a fraction of the baseline mean
    method: str = "frequentist"  # or "bayesian"
class StatisticalEngine:
    """Sample-size planning and two-arm analysis, frequentist or Bayesian.

    ``config.method`` selects ``"frequentist"`` (Welch's t-test) or
    ``"bayesian"`` (Beta-Binomial posterior comparison, binary outcomes only).
    Any other value raises ValueError.
    """

    _METHODS = ("frequentist", "bayesian")

    def __init__(self, config: "ExperimentConfig", seed: Optional[int] = None):
        """
        Args:
            config: experiment settings (alpha, power, mde, method).
            seed: optional seed for Bayesian posterior sampling. ``None`` draws
                from NumPy's global random state, as before.
        """
        self.config = config
        self._rng = np.random.default_rng(seed) if seed is not None else None

    def _method(self) -> str:
        """Return the configured method, rejecting unknown values instead of guessing."""
        method = self.config.method
        if method not in self._METHODS:
            raise ValueError(
                f"unknown analysis method {method!r}; expected one of {self._METHODS}"
            )
        return method

    def compute_sample_size(self, baseline_mean, baseline_var):
        """Return the required sample size per group for the configured method."""
        if self._method() == "frequentist":
            return self._frequentist_sample_size(baseline_mean, baseline_var)
        return self._bayesian_sample_size(baseline_mean, baseline_var)

    def _frequentist_sample_size(self, baseline_mean, baseline_var):
        """n = 2 * var * (z_{1-alpha/2} + z_{power})^2 / delta^2, with delta = mde * baseline_mean."""
        z_alpha = stats.norm.ppf(1 - self.config.alpha / 2)
        z_beta = stats.norm.ppf(self.config.power)
        effect_size = self.config.mde * baseline_mean
        n = 2 * (baseline_var / effect_size**2) * (z_alpha + z_beta)**2
        return int(np.ceil(n))

    def _bayesian_sample_size(self, baseline_mean, baseline_var):
        """Heuristic Bayesian sample size: n = 4 * var / w^2, w = mde * baseline_mean / 2."""
        target_width = self.config.mde * baseline_mean * 0.5
        n = 4 * baseline_var / (target_width**2)
        return int(np.ceil(n))

    def analyze(self, control_data, treatment_data):
        """Compare treatment against control using the configured method."""
        if self._method() == "frequentist":
            return self._frequentist_analysis(control_data, treatment_data)
        return self._bayesian_analysis(control_data, treatment_data)

    def _frequentist_analysis(self, control, treatment):
        """Welch's unequal-variance t-test on (treatment mean - control mean)."""
        n1, n2 = len(control), len(treatment)
        m1, m2 = np.mean(control), np.mean(treatment)
        v1, v2 = np.var(control, ddof=1), np.var(treatment, ddof=1)
        var_m1, var_m2 = v1 / n1, v2 / n2  # squared standard errors of each mean
        se = np.sqrt(var_m1 + var_m2)
        diff = m2 - m1
        t_stat = diff / se
        # Welch-Satterthwaite degrees of freedom.
        df = (var_m1 + var_m2)**2 / (var_m1**2 / (n1 - 1) + var_m2**2 / (n2 - 1))
        # sf() keeps precision for tiny p-values, where 1 - cdf() rounds to 0.
        p_value = 2 * stats.t.sf(abs(t_stat), df)
        # Use the t quantile (not the normal 1.96) so the interval matches the test.
        margin = stats.t.ppf(0.975, df) * se
        return {
            "method": "frequentist",
            "control_mean": m1,
            "treatment_mean": m2,
            "lift": diff / m1,
            "p_value": p_value,
            "ci_95": (diff - margin, diff + margin),
            "significant": p_value < self.config.alpha,
        }

    def _bayesian_analysis(self, control, treatment):
        """Beta-Binomial comparison of rates with a uniform Beta(1, 1) prior.

        Assumes binary (0/1) outcomes: successes are the sum of each input.
        Posterior summaries are Monte Carlo estimates from 100,000 draws.
        """
        s1, n1 = sum(control), len(control)
        s2, n2 = sum(treatment), len(treatment)
        alpha_prior, beta_prior = 1, 1
        post_control = stats.beta(alpha_prior + s1, beta_prior + n1 - s1)
        post_treatment = stats.beta(alpha_prior + s2, beta_prior + n2 - s2)
        n_samples = 100000
        control_samples = post_control.rvs(n_samples, random_state=self._rng)
        treatment_samples = post_treatment.rvs(n_samples, random_state=self._rng)
        prob_better = np.mean(treatment_samples > control_samples)
        lift_samples = (treatment_samples - control_samples) / control_samples
        lift_mean = np.mean(lift_samples)
        return {
            "method": "bayesian",
            "control_mean": s1 / n1,
            "treatment_mean": s2 / n2,
            "prob_better": prob_better,
            # Same key as the frequentist result so callers can read either shape.
            "lift": lift_mean,
            "lift_mean": lift_mean,
            # Equal-tailed 95% credible interval (2.5th-97.5th percentiles);
            # not a true highest-density interval despite the key name.
            "lift_hdi": (np.percentile(lift_samples, 2.5), np.percentile(lift_samples, 97.5)),
        }
# Usage
engine = StatisticalEngine(ExperimentConfig(alpha=0.05, mde=0.03))
sample_size = engine.compute_sample_size(baseline_mean=0.1, baseline_var=0.09)
print(f"Required sample size per group: {sample_size}")
# Simulated binary outcomes (10% vs 11.5% conversion). Note: 5,000 per group is
# far below the required size printed above, so results will vary run to run.
control = np.random.binomial(1, 0.1, 5000)
treatment = np.random.binomial(1, 0.115, 5000)
result = engine.analyze(control, treatment)
# Frequentist results expose "lift"/"significant"; Bayesian results may only
# expose "lift_mean"/"prob_better", so fall back across both shapes.
print(f"Lift: {result.get('lift', result.get('lift_mean')):.3f}, Significant: {result.get('significant', result.get('prob_better', 0) > 0.95)}")
Governance and Reporting
from datetime import datetime
from dataclasses import dataclass
from typing import List
@dataclass
class ExperimentDecision:
    """Record of a decision taken on an experiment."""

    experiment_id: str
    decision: str  # "ship", "iterate", "kill"
    reason: str
    reviewer: str
    # Filled with the local creation time when not supplied.
    timestamp: Optional[datetime] = None

    def __post_init__(self):
        # Without this, every decision created without an explicit time
        # would carry timestamp=None, leaving the audit trail undated.
        if self.timestamp is None:
            self.timestamp = datetime.now()
class ExperimentGovernance:
    """Automated pre-review of experiment results ahead of a human decision.

    ``decisions`` starts empty; nothing in this class appends to it.
    """

    def __init__(self):
        self.decisions: List[ExperimentDecision] = []

    def review_experiment(self, experiment_id, results, config):
        """Run automated checks and return a recommendation.

        Args:
            experiment_id: identifier shown in the summary.
            results: dict with ``n_control``, ``n_treatment``,
                ``duration_days``, ``p_value``, optional ``expected_ratio``,
                and one entry per guardrail metric.
            config: object exposing ``min_sample_size``, ``min_duration``,
                ``guardrails`` (metric -> minimum value) and ``alpha``.

        Returns:
            dict with ``checks`` (name -> bool), ``recommendation``
            (``"ready_for_review"`` only if every check passes, otherwise
            ``"needs_attention"``) and a one-line ``summary``.
        """
        n_control = results.get('n_control', 0)
        n_treatment = results.get('n_treatment', 0)
        checks = {
            # Both arms must reach the minimum, not only control.
            "sample_size_adequate": min(n_control, n_treatment) >= config.min_sample_size,
            "duration_adequate": results.get('duration_days', 0) >= config.min_duration,
            "no_srm": self._check_sample_ratio_mismatch(results),
            "guardrails_green": self._check_guardrails(results, config.guardrails),
            "statistical_significance": results.get('p_value', 1) < config.alpha,
        }
        all_pass = all(checks.values())
        return {
            "checks": checks,
            "recommendation": "ready_for_review" if all_pass else "needs_attention",
            "summary": self._generate_summary(results, checks, experiment_id),
        }

    def _check_sample_ratio_mismatch(self, results):
        """SRM check: return True when arm sizes are consistent with the planned split.

        Chi-square goodness-of-fit of ``[n_control, n_treatment]`` against
        ``expected_ratio`` (default 0.5), applied as the expected control
        share. A p-value of 0.01 or less is treated as a mismatch. Returns
        False when counts are missing or both arms are empty, since balance
        cannot be confirmed.
        """
        n_control = results.get('n_control')
        n_treatment = results.get('n_treatment')
        if n_control is None or n_treatment is None:
            return False
        total = n_control + n_treatment
        if total <= 0:
            return False
        # NOTE(review): expected_ratio is applied to the control arm; confirm
        # callers do not supply the treatment share instead.
        expected_ratio = results.get('expected_ratio', 0.5)
        _, p_value = stats.chisquare(
            [n_control, n_treatment],
            [expected_ratio * total, (1 - expected_ratio) * total],
        )
        return bool(p_value > 0.01)

    def _check_guardrails(self, results, guardrails):
        """Return False if any guardrail metric (0 when missing) falls below its minimum."""
        for metric, threshold in guardrails.items():
            if results.get(metric, 0) < threshold:
                return False
        return True

    def _generate_summary(self, results, checks, experiment_id=None):
        """Build a one-line summary naming the experiment and any failed checks."""
        label = experiment_id if experiment_id is not None else results.get('id', 'unknown')
        failed = [name for name, passed in checks.items() if not passed]
        if not failed:
            return f"Experiment {label}: All checks passed"
        return f"Experiment {label}: Issues detected ({', '.join(failed)})"
Key Takeaways
- Deterministic hash-based assignment gives each user the same variant on every request, ensuring a consistent user experience
- Gradual rollouts catch issues before full deployment
- Statistical engines should support both frequentist and Bayesian methods
- Governance checks (sample size, duration, sample ratio mismatch, guardrails) help prevent bad decisions even when the headline data looks good
- Automate pre-review to focus human attention on judgment calls