ML Feature Store
Feast + Spark + Redis for Real-Time Feature Serving
Project Difficulty: Advanced | Duration: 3-4 weeks | Cloud: AWS/GCP
Build a production ML feature store that manages, serves, and monitors features for machine learning models, with sub-millisecond reads from the Redis online store.
Project Overview
Problem Statement
ML teams struggle with feature engineering duplication, training-serving skew, and slow feature retrieval. Without a centralized feature store, models take weeks to deploy and suffer from inconsistent feature calculations.
Objectives
- Centralize feature definitions and management
- Enable low-latency feature serving for online predictions (sub-millisecond Redis reads)
- Ensure training-serving consistency (no skew)
- Support batch and real-time feature computation
- Provide feature monitoring and drift detection
Tech Stack
| Component | Technology | Purpose |
|---|---|---|
| Feature Store | Feast | Feature management |
| Compute | Apache Spark | Batch feature engineering |
| Online Store | Redis | Low-latency feature serving |
| Offline Store | Snowflake/BigQuery | Training data |
| Orchestration | Airflow | Feature pipeline scheduling |
Data Source Setup and Schema
Entity and Feature Definitions
# features/definitions.py
from feast import Entity, FeatureView, Field, FileSource, ValueType
from feast.types import Float32, Int64, String, Bool
from feast.on_demand_feature_view import on_demand_feature_view
from datetime import timedelta
import pandas as pd
# Entity definitions (value_type takes a feast.ValueType, not a feast.types dtype)
customer_entity = Entity(
name="customer_id",
join_keys=["customer_id"],
description="Unique customer identifier",
value_type=ValueType.STRING
)
product_entity = Entity(
name="product_id",
join_keys=["product_id"],
description="Unique product identifier",
value_type=ValueType.STRING
)
transaction_entity = Entity(
name="transaction_id",
join_keys=["transaction_id"],
description="Unique transaction identifier",
value_type=ValueType.STRING
)
# Batch Feature Views - Customer Features
customer_features = FeatureView(
name="customer_features",
entities=[customer_entity],
ttl=timedelta(days=1),
schema=[
Field(name="lifetime_value", dtype=Float32),
Field(name="total_orders", dtype=Int64),
Field(name="avg_order_value", dtype=Float32),
Field(name="days_since_last_order", dtype=Int64),
Field(name="customer_segment", dtype=String),
Field(name="preferred_category", dtype=String),
Field(name="total_spend_last_30d", dtype=Float32),
Field(name="avg_session_duration", dtype=Float32),
Field(name="preferred_payment_method", dtype=String),
Field(name="is_active", dtype=Bool),
],
source=FileSource(
name="customer_features_source",
path="s3://feature-store/customer_features.parquet",
timestamp_field="event_timestamp",
created_timestamp_column="created_timestamp"
),
tags={"team": "customer_analytics", "layer": "batch"}
)
# Batch Feature Views - Product Features
product_features = FeatureView(
name="product_features",
entities=[product_entity],
ttl=timedelta(days=7),
schema=[
Field(name="price", dtype=Float32),
Field(name="category", dtype=String),
Field(name="subcategory", dtype=String),
Field(name="avg_rating", dtype=Float32),
Field(name="review_count", dtype=Int64),
Field(name="sales_last_30d", dtype=Int64),
Field(name="inventory_level", dtype=Int64),
Field(name="profit_margin", dtype=Float32),
Field(name="is_bestseller", dtype=Bool),
Field(name="price_percentile", dtype=Float32),
],
source=FileSource(
name="product_features_source",
path="s3://feature-store/product_features.parquet",
timestamp_field="event_timestamp"
),
tags={"team": "product_analytics", "layer": "batch"}
)
# Near-real-time customer activity (5-minute TTL). This view reads from a
# FileSource; a true streaming setup would ingest through a push or stream source.
customer_realtime_features = FeatureView(
name="customer_realtime_features",
entities=[customer_entity],
ttl=timedelta(minutes=5),
schema=[
Field(name="clicks_last_5m", dtype=Int64),
Field(name="cart_additions_last_5m", dtype=Int64),
Field(name="session_duration_current", dtype=Float32),
Field(name="pages_viewed_current", dtype=Int64),
Field(name="last_action_type", dtype=String),
Field(name="is_cart_abandoner", dtype=Bool),
],
source=FileSource(
name="customer_realtime_source",
path="s3://feature-store/customer_realtime_features.parquet",
timestamp_field="event_timestamp"
),
tags={"team": "realtime", "layer": "streaming"}
)
# On-Demand Feature Views - computed at request time from the batch views
@on_demand_feature_view(
sources=[customer_features, product_features],
schema=[
Field(name="customer_product_affinity", dtype=Float32),
Field(name="price_sensitivity_score", dtype=Float32),
Field(name="cross_sell_recommendation", dtype=String),
],
description="Computed features combining customer and product data",
tags={"team": "ml", "layer": "computed"}
)
def compute_on_demand_features(features_df: pd.DataFrame) -> pd.DataFrame:
"""Compute on-demand features from batch features."""
result = pd.DataFrame()
# Customer-product affinity (simplified)
result["customer_product_affinity"] = (
features_df["lifetime_value"] * features_df["avg_rating"] / 1000
)
# Price sensitivity score, clipped to [0, 1]
result["price_sensitivity_score"] = (
1 - (features_df["avg_order_value"] / features_df["price"])
).clip(0, 1)
# Cross-sell recommendation
result["cross_sell_recommendation"] = features_df.apply(
lambda row: "premium" if row["customer_segment"] == "VIP" else "standard",
axis=1
)
return result
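A feature view's `ttl` limits how far back a historical lookup may reach for a value. The plain-Python sketch below illustrates the point-in-time (as-of) join idea behind training-set retrieval: each training row receives the latest feature value observed at or before its own timestamp, and nothing older than the TTL. It is a simplified model of the concept, not Feast's implementation; the `point_in_time_join` helper and the in-memory history layout are assumptions for illustration.

```python
from bisect import bisect_right
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical layout: per-entity history of (event_timestamp, value), sorted by time.
FeatureHistory = Dict[str, List[Tuple[datetime, Any]]]


def point_in_time_join(
    entity_rows: List[Tuple[str, datetime]],
    history: FeatureHistory,
    ttl: timedelta,
) -> List[Optional[Any]]:
    """For each (entity_id, event_time), return the latest value observed at or
    before event_time that is no older than ttl, otherwise None."""
    results: List[Optional[Any]] = []
    for entity_id, event_time in entity_rows:
        rows = history.get(entity_id, [])
        timestamps = [ts for ts, _ in rows]
        # Index of the last observation with timestamp <= event_time
        idx = bisect_right(timestamps, event_time) - 1
        if idx < 0:
            results.append(None)  # nothing known yet; never borrow a later value
            continue
        ts, value = rows[idx]
        results.append(value if event_time - ts <= ttl else None)
    return results


history = {
    "c1": [(datetime(2024, 1, 1), 100.0), (datetime(2024, 1, 5), 250.0)],
}
rows = [
    ("c1", datetime(2024, 1, 1, 12)),  # sees the Jan 1 value
    ("c1", datetime(2024, 1, 4)),      # Jan 1 value is older than the 1-day TTL
    ("c1", datetime(2024, 1, 5, 6)),   # sees the Jan 5 value
    ("c2", datetime(2024, 1, 5)),      # unknown entity
]
print(point_in_time_join(rows, history, ttl=timedelta(days=1)))  # [100.0, None, 250.0, None]
```

Rows that predate every observation, or whose latest observation has aged past the TTL, come back as `None` instead of picking up a later value, which is what keeps future information out of the training set.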
Spark Feature Engineering Pipeline
# features/spark_engineering.py
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from datetime import datetime, timedelta
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SparkFeatureEngineer:
def __init__(self, spark: SparkSession, config: dict):
self.spark = spark
self.config = config
self.window_specs = self._define_window_specifications()
def _define_window_specifications(self) -> dict:
"""Define common window specifications."""
# Order by seconds since the epoch so rangeBetween bounds are in seconds
# (casting a date straight to long does not give epoch seconds in Spark).
day_seconds = F.col('event_date').cast('timestamp').cast('long')
return {
'daily': Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-86400, 0),
'weekly': Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-604800, 0),
'monthly': Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-2592000, 0),
'product_daily': Window.partitionBy('product_id').orderBy(day_seconds)
.rangeBetween(-86400, 0)
}
def compute_customer_features(self, transactions_df, clickstream_df):
"""Compute comprehensive customer features."""
logger.info("Computing customer features...")
# Transaction-based features
transaction_features = transactions_df \
.withColumn('event_date', F.to_date('order_date')) \
.groupBy('customer_id', 'event_date') \
.agg(
F.sum('amount').alias('daily_spend'),
F.count('order_id').alias('daily_orders'),
F.countDistinct('product_id').alias('daily_unique_products')
)
# Rolling window features. rangeBetween bounds are in the units of the
# orderBy expression, so order by epoch seconds rather than by the date itself.
day_seconds = F.col('event_date').cast('timestamp').cast('long')
rolling_features = transaction_features \
.withColumn('spend_last_7d', F.sum('daily_spend').over(
Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-604800, 0)
)) \
.withColumn('spend_last_30d', F.sum('daily_spend').over(
Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-2592000, 0)
)) \
.withColumn('orders_last_7d', F.sum('daily_orders').over(
Window.partitionBy('customer_id').orderBy(day_seconds)
.rangeBetween(-604800, 0)
)) \
.withColumn('avg_order_value_7d', F.col('spend_last_7d') / F.col('orders_last_7d'))
# Lifetime features
lifetime_features = transactions_df \
.groupBy('customer_id') \
.agg(
F.sum('amount').alias('lifetime_value'),
F.count('order_id').alias('total_orders'),
F.min('order_date').alias('first_order_date'),
F.max('order_date').alias('last_order_date'),
F.countDistinct(F.to_date('order_date')).alias('active_days')
) \
.withColumn('days_since_last_order',
F.datediff(F.current_date(), F.col('last_order_date'))) \
.withColumn('customer_tenure_days',
F.datediff(F.current_date(), F.col('first_order_date'))) \
.withColumn('order_frequency',
F.col('total_orders') / (F.col('customer_tenure_days') / 30))
# Customer segmentation
segmented_customers = lifetime_features \
.withColumn('customer_segment',
F.when(F.col('lifetime_value') > 10000, 'VIP')
.when(F.col('lifetime_value') > 5000, 'Premium')
.when(F.col('lifetime_value') > 1000, 'Regular')
.otherwise('New')
)
# Clickstream features
clickstream_features = clickstream_df \
.withColumn('event_date', F.to_date('event_timestamp')) \
.groupBy('user_id', 'event_date') \
.agg(
F.count('*').alias('daily_events'),
F.countDistinct('session_id').alias('daily_sessions'),
F.count(F.when(F.col('event_type') == 'page_view', True)).alias('page_views'),
F.count(F.when(F.col('event_type') == 'add_to_cart', True)).alias('cart_adds'),
F.count(F.when(F.col('event_type') == 'purchase', True)).alias('purchases')
) \
.withColumn('conversion_rate',
F.col('purchases') / F.col('page_views')) \
.withColumn('cart_abandonment_rate',
1 - (F.col('purchases') / F.col('cart_adds')))
# Combine all features: one row per customer per event_date. Rolling and
# clickstream aggregates are matched on both customer and date so the join
# does not multiply rows across dates.
daily_features = rolling_features.select(
'customer_id', 'event_date', 'spend_last_7d', 'spend_last_30d',
'orders_last_7d', 'avg_order_value_7d'
).join(
clickstream_features.select(
F.col('user_id').alias('customer_id'), 'event_date',
'daily_events', 'daily_sessions', 'page_views', 'cart_adds'
),
['customer_id', 'event_date'],
'left'
)
final_features = segmented_customers.join(daily_features, 'customer_id', 'left')
return final_features
def compute_product_features(self, products_df, sales_df, reviews_df):
"""Compute comprehensive product features."""
logger.info("Computing product features...")
# Sales-based features
sales_features = sales_df \
.withColumn('event_date', F.to_date('order_date')) \
.groupBy('product_id', 'event_date') \
.agg(
F.sum('quantity').alias('daily_sales'),
F.sum('amount').alias('daily_revenue'),
F.countDistinct('order_id').alias('daily_orders')
)
# Rolling sales features, ordered by epoch seconds so the range bounds are in seconds
day_seconds = F.col('event_date').cast('timestamp').cast('long')
rolling_sales = sales_features \
.withColumn('sales_last_7d', F.sum('daily_sales').over(
Window.partitionBy('product_id').orderBy(day_seconds)
.rangeBetween(-604800, 0)
)) \
.withColumn('sales_last_30d', F.sum('daily_sales').over(
Window.partitionBy('product_id').orderBy(day_seconds)
.rangeBetween(-2592000, 0)
)) \
.withColumn('revenue_last_30d', F.sum('daily_revenue').over(
Window.partitionBy('product_id').orderBy(day_seconds)
.rangeBetween(-2592000, 0)
))
# Review features
review_features = reviews_df \
.groupBy('product_id') \
.agg(
F.avg('rating').alias('avg_rating'),
F.count('*').alias('review_count'),
F.count(F.when(F.col('rating') >= 4, True)).alias('positive_reviews'),
F.count(F.when(F.col('rating') <= 2, True)).alias('negative_reviews')
) \
.withColumn('positive_review_ratio',
F.col('positive_reviews') / F.col('review_count'))
# Product popularity features
popularity_features = sales_df \
.withColumn('event_date', F.to_date('order_date')) \
.groupBy('product_id', 'event_date') \
.agg(
F.countDistinct('customer_id').alias('unique_buyers'),
F.count('*').alias('total_purchases')
) \
.withColumn('popularity_score',
F.col('unique_buyers') * F.col('total_purchases'))
# Combine product features: daily aggregates are matched on product and date
# so the join does not multiply rows across dates.
daily_product_features = rolling_sales.select(
'product_id', 'event_date', 'sales_last_7d', 'sales_last_30d',
'revenue_last_30d'
).join(
popularity_features.select('product_id', 'event_date', 'popularity_score'),
['product_id', 'event_date'],
'left'
)
final_product_features = products_df \
.join(daily_product_features, 'product_id', 'left') \
.join(review_features, 'product_id', 'left')
return final_product_features
def compute_realtime_features(self, clickstream_batch_df,
customer_history_df):
"""Compute real-time features from streaming data."""
logger.info("Computing real-time features...")
# Current session features
session_features = clickstream_batch_df \
.groupBy('user_id', 'session_id') \
.agg(
F.count('*').alias('session_events'),
F.countDistinct('page_url').alias('unique_pages'),
F.min('event_timestamp').alias('session_start'),
F.max('event_timestamp').alias('session_end'),
F.count(F.when(F.col('event_type') == 'add_to_cart', True))
.alias('cart_additions')
) \
.withColumn('session_duration_seconds',
F.unix_timestamp('session_end') - F.unix_timestamp('session_start'))
# Real-time aggregates (last 5 minutes)
realtime_agg = clickstream_batch_df \
.withWatermark('event_timestamp', '2 minutes') \
.groupBy(
F.window('event_timestamp', '5 minutes'),
'user_id'
) \
.agg(
F.count('*').alias('events_last_5m'),
F.count(F.when(F.col('event_type') == 'page_view', True))
.alias('page_views_last_5m'),
F.count(F.when(F.col('event_type') == 'add_to_cart', True))
.alias('cart_adds_last_5m')
)
# Merge with historical features
realtime_features = realtime_agg \
.join(customer_history_df, 'user_id', 'left') \
.withColumn('engagement_score',
F.col('events_last_5m') * 0.3 +
F.col('page_views_last_5m') * 0.4 +
F.col('cart_adds_last_5m') * 0.3
)
return realtime_features
def save_features_to_store(self, features_df, feature_name: str,
output_path: str):
"""Save computed features to the feature store."""
logger.info(f"Saving {feature_name} features to {output_path}")
# Add metadata columns
features_with_metadata = features_df \
.withColumn('event_timestamp', F.current_timestamp()) \
.withColumn('created_timestamp', F.current_timestamp()) \
.withColumn('feature_version', F.lit('v1.0'))
# Write to data lake in Parquet format
features_with_metadata.write \
.mode('overwrite') \
.partitionBy('event_date') \
.parquet(output_path)
# Feature views are registered separately with `feast apply`; this step only
# writes the Parquet data that the feature view sources read.
logger.info(f"Wrote {feature_name} features to {output_path}")
def run_full_feature_pipeline(self):
"""Run the complete feature engineering pipeline."""
logger.info("Starting full feature engineering pipeline")
# Read source data
transactions_df = self.spark.read.format('jdbc') \
.option('url', self.config['db_url']) \
.option('dbtable', 'public.orders') \
.load()
clickstream_df = self.spark.read.parquet(
self.config['clickstream_path']
)
products_df = self.spark.read.format('jdbc') \
.option('url', self.config['db_url']) \
.option('dbtable', 'public.products') \
.load()
sales_df = self.spark.read.format('jdbc') \
.option('url', self.config['db_url']) \
.option('dbtable', 'public.order_items') \
.load()
reviews_df = self.spark.read.parquet(
self.config['reviews_path']
)
# Compute features
customer_features = self.compute_customer_features(
transactions_df, clickstream_df
)
product_features = self.compute_product_features(
products_df, sales_df, reviews_df
)
# Save features
self.save_features_to_store(
customer_features,
'customer_features',
f"{self.config['feature_store_path']}/customer_features"
)
self.save_features_to_store(
product_features,
'product_features',
f"{self.config['feature_store_path']}/product_features"
)
logger.info("Feature engineering pipeline completed")
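Spark's `rangeBetween(start, end)` bounds are measured in the units of the `orderBy` expression, so a seven-day frame written as `-604800` seconds only behaves as intended when rows are ordered by seconds since the epoch. The plain-Python sketch below mimics such a trailing range frame over daily totals to make the semantics concrete; `trailing_range_sum` is a hypothetical helper, and, like Spark's frame, it includes both boundaries.

```python
from datetime import date, datetime, timezone
from typing import Dict, List, Tuple

SECONDS_PER_DAY = 86_400


def to_epoch_seconds(d: date) -> int:
    """Midnight UTC of the given date, in seconds since the epoch."""
    return int(datetime(d.year, d.month, d.day, tzinfo=timezone.utc).timestamp())


def trailing_range_sum(
    daily: List[Tuple[date, float]], range_seconds: int
) -> Dict[date, float]:
    """For each day, sum the values whose epoch-second key lies in
    [current - range_seconds, current], mirroring
    Window.orderBy(<epoch seconds>).rangeBetween(-range_seconds, 0)."""
    keyed = [(to_epoch_seconds(d), d, v) for d, v in daily]
    out: Dict[date, float] = {}
    for cur_key, cur_day, _ in keyed:
        out[cur_day] = sum(
            v for k, _, v in keyed if cur_key - range_seconds <= k <= cur_key
        )
    return out


daily_spend = [
    (date(2024, 1, 1), 10.0),
    (date(2024, 1, 5), 20.0),
    (date(2024, 1, 8), 5.0),
    (date(2024, 1, 9), 1.0),
]
print(trailing_range_sum(daily_spend, 7 * SECONDS_PER_DAY))
```

Because both ends are inclusive, the Jan 8 total (35.0) still contains the Jan 1 spend: a `-7 * 86400` to `0` frame spans eight calendar days. If "last 7 days" should mean the current day plus six previous days, `-6 * 86400` is the matching lower bound.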
Step-by-Step Implementation Guide
Step 1: Feast Configuration
# feast/feature_store.yaml
project: ecommerce_feature_store
registry: s3://feature-store/registry/registry.db
provider: aws
online_store:
  type: redis
  # Primary endpoint of the ElastiCache replication group; TLS is enabled in the connection string
  connection_string: "redis-cluster.internal:6379,ssl=true"
offline_store:
  type: snowflake.offline
  account: ${SNOWFLAKE_ACCOUNT}
  user: ${SNOWFLAKE_USER}
  password: ${SNOWFLAKE_PASSWORD}
  role: FEATURE_STORE_ROLE
  warehouse: FEATURE_STORE_WH
  database: ECOMMERCE
  schema: FEATURES
entity_key_serialization_version: 2
feature_server:
  host: 0.0.0.0
  port: 6566
  enable_prometheus: true
logging:
  level: INFO
  path: /var/log/feast
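The offline-store credentials in this file are `${VAR}` placeholders, on the assumption that they are filled from environment variables when the configuration is loaded. A small pre-flight check such as the hypothetical `missing_env_vars` below, run before `feast apply` or before the serving container starts, fails fast with a clear message instead of a later authentication error. It only scans the text for placeholders; it does not parse YAML or validate Feast's configuration schema.

```python
import os
import re
from typing import List, Mapping, Optional

# Matches ${NAME} placeholders where NAME follows shell identifier rules.
PLACEHOLDER = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def missing_env_vars(config_text: str, env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return placeholder names that are unset or empty in env (defaults to
    os.environ), in first-seen order and without duplicates."""
    env = os.environ if env is None else env
    missing: List[str] = []
    for name in PLACEHOLDER.findall(config_text):
        if not env.get(name) and name not in missing:
            missing.append(name)
    return missing


sample = """
offline_store:
  account: ${SNOWFLAKE_ACCOUNT}
  user: ${SNOWFLAKE_USER}
  password: ${SNOWFLAKE_PASSWORD}
"""
print(missing_env_vars(sample, env={"SNOWFLAKE_ACCOUNT": "acme-xy12345"}))
```

In a deployment script you would pass the contents of `feature_store.yaml` and stop with a non-zero exit code when the returned list is not empty.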
Step 2: Feature Serving API
# serving/feature_server.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import feast
import hashlib
import json
import logging
import os
import pandas as pd
import redis
from datetime import datetime
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
title="Feature Store API",
description="Real-time feature serving for ML models",
version="1.0.0"
)
class FeatureRequest(BaseModel):
entity_rows: List[Dict[str, Any]]
feature_refs: List[str]
full_feature_names: bool = False
class FeatureResponse(BaseModel):
# metadata mixes strings, lists and counts
metadata: Dict[str, Any]
# OnlineResponse.to_dict() maps each feature name to one value per entity row
results: Dict[str, List[Any]]
class FeatureStoreService:
def __init__(self):
# repo_path is the directory that contains feature_store.yaml
self.store = feast.FeatureStore(repo_path="/opt/feast")
# REDIS_HOST is set by the ECS task definition
self.redis_client = redis.Redis(
host=os.environ.get('REDIS_HOST', 'redis-cluster.internal'),
port=6379,
decode_responses=True,
ssl=True
)
self.cache_ttl = 300 # 5 minutes
def get_online_features(self, entity_rows: List[Dict],
feature_refs: List[str]) -> Dict:
"""Get features from online store with caching."""
# Check cache first
cache_key = self._generate_cache_key(entity_rows, feature_refs)
cached = self.redis_client.get(cache_key)
if cached:
logger.info("Returning cached features")
return json.loads(cached)
# Get from Feast online store
response = self.store.get_online_features(
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=True
)
result = {
'metadata': {
'feature_refs': feature_refs,
'entity_count': len(entity_rows),
'timestamp': datetime.utcnow().isoformat()
},
'results': response.to_dict()
}
# Cache the result; default=str serialises timestamps and other values json cannot encode
self.redis_client.setex(
cache_key,
self.cache_ttl,
json.dumps(result, default=str)
)
return result
def get_historical_features(self, entity_df: pd.DataFrame,
feature_refs: List[str]) -> pd.DataFrame:
"""Get historical features for training."""
return self.store.get_historical_features(
entity_df=entity_df,
features=feature_refs
).to_df()
def _generate_cache_key(self, entity_rows: List[Dict],
feature_refs: List[str]) -> str:
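"""Generate a cache key for a feature request.
Entity rows keep their order because results are returned in request order,
so requests listing the same entities in a different order get different keys."""
import hashlib
key_data = {
'entities': entity_rows,
'features': sorted(feature_refs)
}
key_hash = hashlib.sha256(
json.dumps(key_data, sort_keys=True, default=str).encode()
).hexdigest()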
return f"features:{key_hash}"
# Initialize service
feature_service = FeatureStoreService()
@app.post("/get-features", response_model=FeatureResponse)
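def get_features(request: FeatureRequest):
"""Get real-time features for online serving.
A plain def lets FastAPI run the blocking Feast and Redis calls in its threadpool instead of on the event loop."""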
try:
result = feature_service.get_online_features(
entity_rows=request.entity_rows,
feature_refs=request.feature_refs
)
return FeatureResponse(**result)
except Exception as e:
logger.error(f"Feature retrieval failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/feature-views")
async def list_feature_views():
"""List all available feature views."""
feature_views = feature_service.store.list_feature_views()
return {
'feature_views': [
{
'name': fv.name,
'entities': fv.entities,
'ttl': str(fv.ttl),
'features': [f.name for f in fv.features]
}
for fv in feature_views
]
}
@app.get("/health")
def health_check():
"""Health check endpoint."""
return {
'status': 'healthy',
'timestamp': datetime.utcnow().isoformat(),
'redis_connected': feature_service.redis_client.ping()
}
@app.get("/metrics")
async def get_metrics():
"""Get feature store metrics."""
# Placeholder: these values are not collected yet; expose real Prometheus metrics here
return {
'requests_total': 0,
'cache_hit_rate': 0.0,
'avg_latency_ms': 0.0
}
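The `/metrics` endpoint above returns fixed zeros. One way to back it with real numbers, sketched here with only the standard library, is a thread-safe counter that the request handler updates on every lookup. `ServingMetrics` and its method names are hypothetical; a production server would more likely export the same counters through a Prometheus client library.

```python
import threading
import time
from contextlib import contextmanager
from typing import Dict, Iterator


class ServingMetrics:
    """Thread-safe request, cache-hit and latency counters for a feature server."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._requests = 0
        self._cache_hits = 0
        self._total_latency_ms = 0.0

    def record(self, latency_ms: float, cache_hit: bool) -> None:
        with self._lock:
            self._requests += 1
            self._cache_hits += int(cache_hit)
            self._total_latency_ms += latency_ms

    @contextmanager
    def timed(self) -> Iterator[Dict[str, bool]]:
        """Time the enclosed block; the caller sets state['cache_hit'] inside it.
        The request is recorded even if the block raises."""
        state = {"cache_hit": False}
        start = time.perf_counter()
        try:
            yield state
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1000
            self.record(elapsed_ms, state["cache_hit"])

    def snapshot(self) -> Dict[str, float]:
        with self._lock:
            n = self._requests
            return {
                "requests_total": n,
                "cache_hit_rate": self._cache_hits / n if n else 0.0,
                "avg_latency_ms": self._total_latency_ms / n if n else 0.0,
            }


metrics = ServingMetrics()
metrics.record(latency_ms=2.0, cache_hit=True)
metrics.record(latency_ms=4.0, cache_hit=False)
with metrics.timed() as state:
    state["cache_hit"] = True  # e.g. set when the Redis cache returned a value
print(metrics.snapshot())
```

In the handler, the online lookup would run inside `with metrics.timed() as state:`, and `/metrics` would return `metrics.snapshot()`. The averages are cumulative since process start; percentiles would need a histogram instead.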
Step 3: Feature Monitoring
# monitoring/feature_monitor.py
import pandas as pd
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FeatureMonitor:
def __init__(self, feature_store, alert_callback=None):
self.feature_store = feature_store
self.alert_callback = alert_callback
self.baseline_stats = {}
def compute_feature_statistics(self, features_df: pd.DataFrame) -> Dict:
"""Compute comprehensive statistics for features."""
feature_stats = {}  # named to avoid shadowing the scipy.stats import
for column in features_df.select_dtypes(include=[np.number]).columns:
series = features_df[column]
q1 = series.quantile(0.25)
q3 = series.quantile(0.75)
iqr = q3 - q1
col_stats = {
'mean': float(series.mean()),
'std': float(series.std()),
'min': float(series.min()),
'max': float(series.max()),
'median': float(series.median()),
'null_count': int(series.isnull().sum()),
'null_ratio': float(series.isnull().mean()),
'skewness': float(series.skew()),
'kurtosis': float(series.kurtosis()),
'percentile_25': float(q1),
'percentile_75': float(q3),
'iqr': float(iqr)
}
# Detect outliers using the 1.5 * IQR rule
outlier_count = int(((series < q1 - 1.5 * iqr) |
(series > q3 + 1.5 * iqr)).sum())
col_stats['outlier_count'] = outlier_count
col_stats['outlier_ratio'] = outlier_count / len(features_df) if len(features_df) else 0.0
feature_stats[column] = col_stats
return feature_stats
def detect_data_drift(self, current_stats: Dict,
baseline_stats: Dict,
threshold: float = 0.1) -> List[Dict]:
"""Detect drift by comparing current summary statistics with a baseline."""
drift_alerts = []
for feature_name, current in current_stats.items():
if feature_name not in baseline_stats:
continue
baseline = baseline_stats[feature_name]
# Standardized mean shift: how many baseline standard deviations the mean
# has moved. Summary statistics cannot support a KS test or a true PSI;
# both need the raw values or histograms.
mean_shift = self._standardized_mean_shift(
baseline['mean'],
current['mean'],
baseline['std']
)
if mean_shift > threshold:
drift_alerts.append({
'feature': feature_name,
'drift_type': 'distribution',
'mean_shift': mean_shift,
'baseline_mean': baseline['mean'],
'current_mean': current['mean'],
'severity': 'high' if mean_shift > threshold * 2 else 'medium'
})
# Null ratio drift
null_ratio_change = abs(current['null_ratio'] - baseline['null_ratio'])
if null_ratio_change > 0.05:
drift_alerts.append({
'feature': feature_name,
'drift_type': 'null_ratio',
'baseline_null_ratio': baseline['null_ratio'],
'current_null_ratio': current['null_ratio'],
'severity': 'medium'
})
return drift_alerts
def _standardized_mean_shift(self, baseline_mean: float, current_mean: float,
baseline_std: float) -> float:
"""Absolute change in the mean in units of the baseline standard deviation.
A simple drift proxy, not the Population Stability Index."""
if baseline_std == 0:
# Constant baseline: any change in the mean counts as unbounded drift
return 0.0 if current_mean == baseline_mean else float('inf')
return abs(current_mean - baseline_mean) / baseline_std
def validate_feature_freshness(self, feature_view: str,
max_age_seconds: int = 300) -> Dict:
"""Validate feature freshness and timeliness."""
# Get latest feature timestamp
latest_features = self.feature_store.get_online_features(
features=[f"{feature_view}:event_timestamp"],
entity_rows=[{'customer_id': 'latest_check'}]
)
if latest_features:
latest_timestamp = datetime.fromisoformat(
latest_features['metadata']['timestamp']
)
age_seconds = (datetime.utcnow() - latest_timestamp).total_seconds()
is_fresh = age_seconds <= max_age_seconds
return {
'feature_view': feature_view,
'latest_timestamp': latest_timestamp.isoformat(),
'age_seconds': age_seconds,
'is_fresh': is_fresh,
'max_age_seconds': max_age_seconds
}
return {
'feature_view': feature_view,
'is_fresh': False,
'error': 'No features found'
}
def monitor_feature_quality(self, features_df: pd.DataFrame) -> Dict:
"""Comprehensive feature quality monitoring."""
quality_report = {
'timestamp': datetime.utcnow().isoformat(),
'total_features': len(features_df.columns),
'total_rows': len(features_df),
'issues': []
}
for column in features_df.columns:
feature_issues = []
# Check for high null ratio
null_ratio = features_df[column].isnull().mean()
if null_ratio > 0.1:
feature_issues.append({
'type': 'high_null_ratio',
'value': float(null_ratio),
'threshold': 0.1
})
# Check for constant features
if features_df[column].nunique() == 1:
feature_issues.append({
'type': 'constant_feature',
'unique_values': 1
})
# Check for high cardinality in categorical features
if features_df[column].dtype == 'object':
cardinality = features_df[column].nunique()
if cardinality > 100:
feature_issues.append({
'type': 'high_cardinality',
'cardinality': cardinality,
'threshold': 100
})
# Check for numeric outliers
if pd.api.types.is_numeric_dtype(features_df[column]):
q1 = features_df[column].quantile(0.25)
q3 = features_df[column].quantile(0.75)
iqr = q3 - q1
outlier_ratio = ((features_df[column] < q1 - 3 * iqr) |
(features_df[column] > q3 + 3 * iqr)).mean()
if outlier_ratio > 0.01:
feature_issues.append({
'type': 'extreme_outliers',
'ratio': float(outlier_ratio),
'threshold': 0.01
})
if feature_issues:
quality_report['issues'].append({
'feature': column,
'issues': feature_issues
})
quality_report['features_with_issues'] = len(quality_report['issues'])
quality_report['quality_score'] = (
1 - (quality_report['features_with_issues'] /
quality_report['total_features'])
) * 100
return quality_report
def generate_monitoring_report(self) -> Dict:
"""Generate comprehensive monitoring report."""
report = {
'timestamp': datetime.utcnow().isoformat(),
'feature_views': {}
}
# Monitor each feature view
feature_views = self.feature_store.list_feature_views()
for fv in feature_views:
logger.info(f"Monitoring feature view: {fv.name}")
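# Get sample data keyed by this view's own join key (synthetic IDs such as customer_0 or product_0)
join_key = fv.entities[0]
sample_entities = [{join_key: f"{join_key.replace('_id', '')}_{i}"}
for i in range(1000)]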
features = self.feature_store.get_online_features(
features=[f"{fv.name}:{f.name}" for f in fv.features],
entity_rows=sample_entities
)
if features:
features_df = pd.DataFrame(features['results'])
# Compute statistics
stats = self.compute_feature_statistics(features_df)
# Check freshness
freshness = self.validate_feature_freshness(fv.name)
# Check quality
quality = self.monitor_feature_quality(features_df)
report['feature_views'][fv.name] = {
'statistics': stats,
'freshness': freshness,
'quality': quality
}
# Check for drift per feature view if a baseline exists
if self.baseline_stats:
drift_alerts = []
for fv_name, fv_report in report['feature_views'].items():
if fv_name in self.baseline_stats:
drift_alerts.extend(self.detect_data_drift(
fv_report['statistics'],
self.baseline_stats[fv_name]
))
report['drift_alerts'] = drift_alerts
# Store these statistics as the baseline for the next run (rolling baseline)
self.baseline_stats = {
fv_name: fv_report['statistics']
for fv_name, fv_report in report['feature_views'].items()
}
return report
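Summary statistics such as the mean and standard deviation cannot support a Kolmogorov-Smirnov test or a true Population Stability Index; both need the underlying values. A minimal sketch, assuming you keep a sample of raw values for the baseline and current windows (for example a few thousand values from `features_df[column].dropna()`): PSI bins both samples at the baseline's quantiles and sums (a_i − e_i) · ln(a_i / e_i) over the bins, where e_i and a_i are the baseline and current bin shares, and the two-sample KS statistic is the largest gap between the two empirical CDFs. The function names, the default of ten bins, and the epsilon used for empty bins are illustrative choices.

```python
import math
from bisect import bisect_right
from typing import List, Sequence


def quantile_edges(baseline: Sequence[float], bins: int) -> List[float]:
    """Interior bin edges taken at the baseline's empirical quantiles."""
    ordered = sorted(baseline)
    n = len(ordered)
    return [ordered[min(n - 1, (i * n) // bins)] for i in range(1, bins)]


def bin_proportions(values: Sequence[float], edges: List[float]) -> List[float]:
    """Share of values per bin; bin k holds edges[k-1] <= v < edges[k]."""
    counts = [0] * (len(edges) + 1)
    for v in values:
        counts[bisect_right(edges, v)] += 1
    return [c / len(values) for c in counts]


def psi(baseline: Sequence[float], current: Sequence[float],
        bins: int = 10, eps: float = 1e-6) -> float:
    """Population Stability Index: sum of (a - e) * ln(a / e) over the bins."""
    edges = quantile_edges(baseline, bins)
    expected = bin_proportions(baseline, edges)
    actual = bin_proportions(current, edges)
    total = 0.0
    for e, a in zip(expected, actual):
        e, a = max(e, eps), max(a, eps)  # floor empty bins to avoid log(0)
        total += (a - e) * math.log(a / e)
    return total


def ks_statistic(sample_a: Sequence[float], sample_b: Sequence[float]) -> float:
    """Two-sample Kolmogorov-Smirnov statistic, max |F_a(x) - F_b(x)|.
    Both empirical CDFs only jump at sample points, so checking every pooled
    point suffices; integer counts keep the comparison exact."""
    a, b = sorted(sample_a), sorted(sample_b)
    na, nb = len(a), len(b)
    best = 0
    for x in a + b:
        best = max(best, abs(bisect_right(a, x) * nb - bisect_right(b, x) * na))
    return best / (na * nb)


baseline = [float(i) for i in range(100)]
shifted = [x + 50 for x in baseline]
print(psi(baseline, baseline), psi(baseline, shifted), ks_statistic(baseline, shifted))
```

Identical samples give a PSI and KS statistic of zero, while shifting every value by half the baseline range pushes both well above zero. A p-value for the KS statistic would still come from a library such as `scipy.stats.ks_2samp` applied to the same raw samples.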
Infrastructure Setup (Terraform)
# infrastructure/ml_feature_store.tf
terraform {
required_version = ">= 1.5.0"
required_providers {
aws = {
source = "hashicorp/aws"
version = "~> 5.0"
}
}
}
# Redis Cluster for Online Store
resource "aws_elasticache_replication_group" "feature_redis" {
replication_group_id = "ml-feature-store-redis"
description = "Redis cluster for ML feature store online serving"
node_type = "cache.r6g.xlarge"
num_cache_clusters = 6
engine = "redis"
engine_version = "7.0"
port = 6379
subnet_group_name = aws_elasticache_subnet_group.redis.name
security_group_ids = [aws_security_group.redis.id]
at_rest_encryption_enabled = true
transit_encryption_enabled = true
automatic_failover_enabled = true
multi_az_enabled = true
snapshot_retention_limit = 7
snapshot_window = "03:00-04:00"
tags = {
Project = "ml-feature-store"
Purpose = "online-feature-serving"
}
}
# ElastiCache Subnet Group
resource "aws_elasticache_subnet_group" "redis" {
name = "ml-feature-redis-subnet"
subnet_ids = module.vpc.private_subnets
}
# Security Group for Redis
resource "aws_security_group" "redis" {
name_prefix = "ml-feature-redis-"
vpc_id = module.vpc.vpc_id
ingress {
from_port = 6379
to_port = 6379
protocol = "tcp"
cidr_blocks = [module.vpc.vpc_cidr_block]
description = "Redis"
}
egress {
from_port = 0
to_port = 0
protocol = "-1"
cidr_blocks = ["0.0.0.0/0"]
}
}
# S3 Bucket for Feature Store
resource "aws_s3_bucket" "feature_store" {
bucket = "ml-feature-store-${var.environment}"
tags = {
Project = "ml-feature-store"
Purpose = "feature-storage"
}
}
resource "aws_s3_bucket_versioning" "feature_store" {
bucket = aws_s3_bucket.feature_store.id
versioning_configuration {
status = "Enabled"
}
}
resource "aws_s3_bucket_server_side_encryption_configuration" "feature_store" {
bucket = aws_s3_bucket.feature_store.id
rule {
apply_server_side_encryption_by_default {
sse_algorithm = "aws:kms"
}
}
}
# EMR Cluster for Feature Engineering
resource "aws_emr_cluster" "feature_engineering" {
name = "ml-feature-engineering"
release_label = "emr-6.15.0"
applications = ["Spark", "Hive"]
master_instance_group {
instance_type = "m5.xlarge"
instance_count = 1
}
core_instance_group {
instance_type = "r5.2xlarge"
instance_count = 3
ebs_config {
type = "gp3"
size = 100
volumes_per_instance = 1
}
}
ec2_attributes {
instance_profile = aws_iam_instance_profile.emr.arn
subnet_id = module.vpc.private_subnets[0]
}
service_role = aws_iam_role.emr_service.arn
tags = {
Project = "ml-feature-store"
Purpose = "feature-engineering"
}
}
# ECS Cluster for Feature Server
resource "aws_ecs_cluster" "feature_server" {
name = "ml-feature-server"
setting {
name = "containerInsights"
value = "enabled"
}
tags = {
Project = "ml-feature-store"
Purpose = "feature-serving"
}
}
# ECS Task Definition for Feature Server
resource "aws_ecs_task_definition" "feature_server" {
family = "feature-server"
network_mode = "awsvpc"
requires_compatibilities = ["FARGATE"]
cpu = "2048"
memory = "4096"
container_definitions = jsonencode([
{
name = "feature-server"
image = "${var.aws_account_id}.dkr.ecr.${var.aws_region}.amazonaws.com/feature-server:latest"
portMappings = [
{
containerPort = 6566
hostPort = 6566
protocol = "tcp"
}
]
environment = [
{
name = "FEAST_CONFIG_PATH"
value = "/opt/feast/feature_store.yaml"
},
{
name = "REDIS_HOST"
value = aws_elasticache_replication_group.feature_redis.primary_endpoint_address
}
]
logConfiguration = {
logDriver = "awslogs"
options = {
"awslogs-group" = aws_cloudwatch_log_group.feature_server.name
"awslogs-region" = var.aws_region
"awslogs-stream-prefix" = "feature-server"
}
}
}
])
execution_role_arn = aws_iam_role.ecs_execution.arn
task_role_arn = aws_iam_role.ecs_task.arn
}
# ECS Service for Feature Server
resource "aws_ecs_service" "feature_server" {
name = "feature-server"
cluster = aws_ecs_cluster.feature_server.id
task_definition = aws_ecs_task_definition.feature_server.arn
desired_count = 3
network_configuration {
subnets = module.vpc.private_subnets
security_groups = [aws_security_group.feature_server.id]
assign_public_ip = false
}
load_balancer {
target_group_arn = aws_lb_target_group.feature_server.arn
container_name = "feature-server"
container_port = 6566
}
depends_on = [aws_lb_listener.feature_server]
}
# Application Load Balancer
resource "aws_lb" "feature_server" {
name = "feature-server-alb"
internal = true
load_balancer_type = "application"
security_groups = [aws_security_group.alb.id]
subnets = module.vpc.private_subnets
tags = {
Project = "ml-feature-store"
}
}
resource "aws_lb_target_group" "feature_server" {
name = "feature-server-tg"
port = 6566
protocol = "HTTP"
vpc_id = module.vpc.vpc_id
target_type = "ip"
health_check {
path = "/health"
port = "traffic-port"
healthy_threshold = 3
unhealthy_threshold = 3
timeout = 5
interval = 30
matcher = "200"
}
}
resource "aws_lb_listener" "feature_server" {
load_balancer_arn = aws_lb.feature_server.arn
port = 80
protocol = "HTTP"
default_action {
type = "forward"
target_group_arn = aws_lb_target_group.feature_server.arn
}
}
# CloudWatch Log Group
resource "aws_cloudwatch_log_group" "feature_server" {
name = "/ecs/feature-server"
retention_in_days = 30
}
# IAM Roles
resource "aws_iam_role" "ecs_execution" {
name = "ecs-feature-server-execution"
assume_role_policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = "sts:AssumeRole"
Effect = "Allow"
Principal = {
Service = "ecs-tasks.amazonaws.com"
}
}
]
})
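}
# Lets ECS pull the image from ECR and write to CloudWatch Logs
resource "aws_iam_role_policy_attachment" "ecs_execution" {
role = aws_iam_role.ecs_execution.name
policy_arn = "arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy"
}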
resource "aws_iam_role" "ecs_task" {
name = "ecs-feature-server-task"
assume_role_policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Action = "sts:AssumeRole"
Effect = "Allow"
Principal = {
Service = "ecs-tasks.amazonaws.com"
}
}
]
})
}
resource "aws_iam_role_policy" "ecs_task" {
name = "ecs-feature-server-task-policy"
role = aws_iam_role.ecs_task.id
policy = jsonencode({
Version = "2012-10-17"
Statement = [
{
Effect = "Allow"
Action = [
# Redis access is controlled by the security group; no ElastiCache IAM actions are needed
"s3:GetObject",
"s3:PutObject"
]
Resource = "*"
}
]
})
}
# Outputs
output "redis_endpoint" {
description = "Redis endpoint for online feature serving"
value = aws_elasticache_replication_group.feature_redis.primary_endpoint_address
}
output "feature_store_bucket" {
description = "S3 bucket for feature store"
value = aws_s3_bucket.feature_store.bucket
}
output "feature_server_dns" {
description = "Feature server DNS name"
value = aws_lb.feature_server.dns_name
}
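After `terraform apply`, clients inside the VPC reach the Feast feature server through the internal ALB at the `feature_server_dns` output. Feast's HTTP feature server accepts a POST to `/get-online-features` whose `entities` field is column-oriented (one list of values per join key), so callers holding row-oriented entity rows must pivot them first. The following is a minimal sketch that uses only the standard library. `build_online_request` and `fetch_online_features` are hypothetical helpers, and the base URL (for example `http://<feature_server_dns>`) is an assumption about how you expose the ALB.

```python
import json
import urllib.request
from typing import Any


def build_online_request(feature_refs: list[str],
                         entity_rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Pivot row-oriented entity rows into the column-oriented request body
    used by the Feast HTTP feature server's /get-online-features endpoint."""
    if not entity_rows:
        raise ValueError("at least one entity row is required")
    join_keys = list(entity_rows[0])
    entities: dict[str, list[Any]] = {key: [] for key in join_keys}
    for row in entity_rows:
        # Every row must carry exactly the same join keys, otherwise the
        # per-key columns would no longer line up row by row.
        if set(row) != set(join_keys):
            raise ValueError("every entity row must use the same join keys")
        for key in join_keys:
            entities[key].append(row[key])
    return {"features": feature_refs, "entities": entities}


def fetch_online_features(base_url: str, feature_refs: list[str],
                          entity_rows: list[dict[str, Any]],
                          timeout: float = 2.0) -> dict[str, Any]:
    """POST the request to the feature server and return the parsed JSON."""
    body = json.dumps(build_online_request(feature_refs, entity_rows)).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url.rstrip('/')}/get-online-features",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())
```

Keeping the pivot in a separate pure function makes it testable without a running server. The short timeout reflects that online lookups sit on the request path of a prediction.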
Testing and Validation
# tests/test_feature_store.py
import pytest
import pandas as pd
import numpy as np
from datetime import datetime
from unittest.mock import Mock, patch
class TestFeatureStore:
    @pytest.fixture
    def sample_features(self):
        return pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(1000)],
            'lifetime_value': np.random.uniform(0, 10000, 1000),
            'total_orders': np.random.randint(1, 50, 1000),
            'avg_order_value': np.random.uniform(10, 500, 1000),
            'days_since_last_order': np.random.randint(0, 365, 1000),
            'customer_segment': np.random.choice(['VIP', 'Premium', 'Regular', 'New'], 1000)
        })
    @pytest.fixture
    def feature_monitor(self):
        from monitoring.feature_monitor import FeatureMonitor
        mock_store = Mock()
        return FeatureMonitor(mock_store)
    def test_feature_statistics(self, feature_monitor, sample_features):
        """Test feature statistics computation."""
        stats = feature_monitor.compute_feature_statistics(sample_features)
        assert 'lifetime_value' in stats
        assert 'total_orders' in stats
        # Check statistics are computed
        assert 'mean' in stats['lifetime_value']
        assert 'std' in stats['lifetime_value']
        assert 'null_count' in stats['lifetime_value']
    def test_data_drift_detection(self, feature_monitor):
        """Test data drift detection."""
        baseline_stats = {
            'feature_1': {
                'mean': 100.0,
                'std': 10.0,
                'null_ratio': 0.01
            }
        }
        current_stats = {
            'feature_1': {
                'mean': 120.0,
                'std': 15.0,
                'null_ratio': 0.05
            }
        }
        drift_alerts = feature_monitor.detect_data_drift(
            current_stats, baseline_stats
        )
        assert len(drift_alerts) > 0
        assert drift_alerts[0]['feature'] == 'feature_1'
    def test_feature_freshness(self, feature_monitor):
        """Test feature freshness validation."""
        with patch.object(feature_monitor.feature_store, 'get_online_features') as mock:
            mock.return_value = {
                'metadata': {
                    'timestamp': datetime.utcnow().isoformat()
                }
            }
            freshness = feature_monitor.validate_feature_freshness(
                'test_feature_view',
                max_age_seconds=300
            )
            assert freshness['is_fresh'] is True
    def test_feature_quality(self, feature_monitor):
        """Test feature quality monitoring."""
        # Create features with quality issues
        problematic_features = pd.DataFrame({
            'good_feature': np.random.uniform(0, 1, 1000),
            # 500 NaNs followed by 500 valid values (list + ndarray would add element-wise)
            'high_null_feature': np.concatenate([np.full(500, np.nan), np.random.uniform(0, 1, 500)]),
            'constant_feature': [1.0] * 1000,
            'high_cardinality': [f'value_{i}' for i in range(1000)]
        })
        quality_report = feature_monitor.monitor_feature_quality(
            problematic_features
        )
        assert quality_report['features_with_issues'] > 0
        assert quality_report['quality_score'] < 100
    def test_online_feature_serving(self):
        """Test online feature serving API."""
        from serving.feature_server import FeatureStoreService
        mock_store = Mock()
        mock_store.get_online_features.return_value = Mock(
            to_dict=lambda: {
                'results': [
                    {'customer_id': 'customer_1', 'lifetime_value': 5000.0}
                ]
            }
        )
        service = FeatureStoreService()
        service.store = mock_store
        result = service.get_online_features(
            entity_rows=[{'customer_id': 'customer_1'}],
            feature_refs=['customer_features:lifetime_value']
        )
        assert 'metadata' in result
        assert 'results' in result
    def test_feature_cache(self):
        """Test feature caching in Redis."""
        import json
        # Exercises only the Redis call contract (get on miss, then setex with a TTL)
        # against a mock client; it does not touch a real Redis instance.
        mock_redis = Mock()
        mock_redis.get.return_value = None
        mock_redis.setex.return_value = True
        cache_key = "features:test_hash"
        cache_ttl = 300
        # Simulate cache miss
        cached = mock_redis.get(cache_key)
        assert cached is None
        # Simulate cache set
        feature_data = {'test': 'data'}
        mock_redis.setex(cache_key, cache_ttl, json.dumps(feature_data))
        mock_redis.setex.assert_called_once_with(
            cache_key, cache_ttl, json.dumps(feature_data)
        )
    def test_historical_features(self):
        """Test historical feature retrieval for training."""
        # Create entity dataframe
        entity_df = pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(100)],
            'event_date': [datetime(2024, 1, 15)] * 100
        })
        # Simulated result; a full test would call store.get_historical_features(entity_df, ...)
        historical_features = pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(100)],
            'event_date': [datetime(2024, 1, 15)] * 100,
            'lifetime_value': np.random.uniform(0, 10000, 100),
            'total_orders': np.random.randint(1, 50, 100)
        })
        # Verify schema
        assert 'customer_id' in historical_features.columns
        assert 'lifetime_value' in historical_features.columns
        assert len(historical_features) == len(entity_df)
    def test_feature_pipeline(self):
        """Test end-to-end feature pipeline."""
        from pyspark.sql import SparkSession
        from features.spark_engineering import SparkFeatureEngineer
        spark = SparkSession.builder \
            .appName("TestFeaturePipeline") \
            .master("local[*]") \
            .getOrCreate()
        try:
            # Create sample transaction data
            transactions_data = [
                ('customer_1', 'order_1', 100.0, '2024-01-15'),
                ('customer_1', 'order_2', 150.0, '2024-01-16'),
                ('customer_2', 'order_3', 200.0, '2024-01-15')
            ]
            transactions_df = spark.createDataFrame(
                transactions_data,
                ['customer_id', 'order_id', 'amount', 'order_date']
            )
            # Smoke test: the engineer can be constructed with a test config
            config = {'db_url': 'jdbc:postgresql://localhost/test'}
            engineer = SparkFeatureEngineer(spark, config)
            # Simple aggregation test
            customer_features = transactions_df \
                .groupBy('customer_id') \
                .agg({'amount': 'sum', 'order_id': 'count'})
            assert 'customer_id' in customer_features.columns
            assert 'sum(amount)' in customer_features.columns
            assert 'count(order_id)' in customer_features.columns
        finally:
            # Always release the local Spark session, even if an assertion fails
            spark.stop()
    def test_feature_validation(self):
        """Test feature validation rules."""
        validation_rules = [
            {
                'feature': 'lifetime_value',
                'type': 'range',
                'min': 0,
                'max': 100000
            },
            {
                'feature': 'customer_segment',
                'type': 'enum',
                'values': ['VIP', 'Premium', 'Regular', 'New']
            },
            {
                'feature': 'total_orders',
                'type': 'non_negative'
            }
        ]
        # Test data
        test_data = {
            'lifetime_value': 5000,
            'customer_segment': 'VIP',
            'total_orders': 10
        }
        # Validate
        for rule in validation_rules:
            feature = rule['feature']
            value = test_data[feature]
            if rule['type'] == 'range':
                assert rule['min'] <= value <= rule['max']
            elif rule['type'] == 'enum':
                assert value in rule['values']
            elif rule['type'] == 'non_negative':
                assert value >= 0
    def test_feature_monitoring_report(self, feature_monitor, sample_features):
        """Test comprehensive monitoring report generation."""
        with patch.object(feature_monitor.feature_store, 'list_feature_views') as mock_views:
            mock_fv = Mock()
            mock_fv.name = 'customer_features'
            # Mock(name=...) sets the mock's repr name, not its .name attribute,
            # so assign .name explicitly on each feature mock
            lifetime_value = Mock()
            lifetime_value.name = 'lifetime_value'
            total_orders = Mock()
            total_orders.name = 'total_orders'
            mock_fv.features = [lifetime_value, total_orders]
            mock_views.return_value = [mock_fv]
            with patch.object(feature_monitor.feature_store, 'get_online_features') as mock_features:
                mock_features.return_value = {
                    'results': sample_features.to_dict('records')
                }
                report = feature_monitor.generate_monitoring_report()
                assert 'feature_views' in report
                assert 'customer_features' in report['feature_views']
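`test_feature_validation` applies its rules inline and stops at the first failed assert. In a pipeline, you usually want a reusable check that reports every violation in a row. The following is a minimal sketch that assumes rules shaped like the ones in that test (`range`, `enum`, `non_negative`). `validate_row` is a hypothetical helper, not part of Feast.

```python
from typing import Any


def validate_row(row: dict[str, Any], rules: list[dict[str, Any]]) -> list[str]:
    """Return one human-readable message per rule the row violates.

    Supports 'range' (inclusive min/max), 'enum' (allowed values) and
    'non_negative'. A missing or None value is reported, not raised.
    """
    violations = []
    for rule in rules:
        name = rule["feature"]
        value = row.get(name)
        if value is None:
            violations.append(f"{name}: missing value")
            continue
        kind = rule["type"]
        if kind == "range" and not rule["min"] <= value <= rule["max"]:
            violations.append(f"{name}: {value} outside [{rule['min']}, {rule['max']}]")
        elif kind == "enum" and value not in rule["values"]:
            violations.append(f"{name}: {value!r} not in {rule['values']}")
        elif kind == "non_negative" and value < 0:
            violations.append(f"{name}: {value} is negative")
        elif kind not in ("range", "enum", "non_negative"):
            raise ValueError(f"unknown rule type: {kind}")
    return violations
```

Returning a list, rather than asserting, lets the same function drive both unit tests (`assert validate_row(...) == []`) and pipeline checks that quarantine bad rows.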
Cost Analysis
Monthly Cost Breakdown (Production)
| Component | Specification | Monthly Cost |
|---|---|---|
| Redis Cluster | 6x cache.r6g.xlarge | $2,400 |
| EMR Cluster | 3x r5.2xlarge (intermittent) | $800 |
| ECS (Feature Server) | 3x Fargate tasks | $300 |
| S3 Storage | 2TB feature data | $46 |
| Snowflake | Feature store queries | $200 |
| Data Transfer | Cross-AZ | $100 |
| Total | | $3,846 |
Cost Optimization Strategies
💡
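Tip: Optimize feature store costs with these strategies:
- Feature Caching: at a 90% cache hit rate, only about 1 in 10 requests needs a full feature lookup, which cuts compute on the serving path
- Feature Pruning: remove low-importance features to save storage, compute, and Redis memory
- Batch vs Real-time: compute features in batch unless a model genuinely needs fresher values, because streaming pipelines cost more to run
- Spot Instances: use them for EMR feature engineering jobs, which can be retried if an instance is interrupted
- Auto-scaling: scale Redis shards and replicas with traffic instead of provisioning for peak load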
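You can reason about the caching tip with simple arithmetic. With a hit rate h, only a fraction (1 - h) of lookups miss the cache and fall through to the full retrieval path. The following sketch makes that concrete. The 10,000 lookups per second figure is a hypothetical load, not a measurement from this project.

```python
def backend_reads(requests_per_sec: float, cache_hit_rate: float) -> float:
    """Lookups per second that miss the cache and need a full feature retrieval."""
    if not 0.0 <= cache_hit_rate <= 1.0:
        raise ValueError("cache_hit_rate must be between 0 and 1")
    return requests_per_sec * (1.0 - cache_hit_rate)


# Hypothetical load of 10,000 feature lookups per second.
for hit_rate in (0.0, 0.5, 0.9, 0.99):
    print(f"hit rate {hit_rate:.0%}: {backend_reads(10_000, hit_rate):,.0f} reads/s")
```

The step from 90% to 99% removes another 90% of the remaining backend reads. For this reason, small hit-rate gains matter more than they look.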
Performance Metrics
| Metric | Before Feature Store | After Feature Store | Improvement |
|---|---|---|---|
| Feature Development | 2 weeks | 2 days | 7x faster |
| Model Deployment | 1 week | 1 hour | 168x faster |
| Feature Serving Latency | 50ms | 2ms | 25x faster |
| Training-Serving Skew | 15% | < 1% | > 15x lower |
Interview Talking Points
Architecture Decisions
ℹ️
Best Practice: Focus on these feature store concepts in interviews:
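- Why Feast over a custom feature store?
  - Open-source and cloud-agnostic
  - Integrates with Spark for offline storage and materialization, and ingests streaming features through push and stream sources
  - Clear separation between online and offline stores
  - A central registry of feature definitions that can be versioned in Git
- Why Redis for the online store?
  - Sub-millisecond in-memory reads
  - Native data structures such as hashes, which Feast uses to store each entity's features
  - Cluster mode for horizontal scaling
  - Pub/Sub for real-time updates
- How do you prevent training-serving skew?
  - Use the same feature definitions for training and serving
  - Enforce point-in-time correctness when building training data
  - Version features and snapshot training datasets
  - Run automated validation tests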
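Point-in-time correctness means that each training example sees only the feature values known at its event time. Feast's `get_historical_features` handles this for you. The following sketch reproduces the idea with pandas `merge_asof` on made-up data. The column names `event_timestamp` and `feature_timestamp` are assumptions for illustration.

```python
import pandas as pd

# Label events: the moments at which predictions would have been made.
entity_df = pd.DataFrame({
    "customer_id": ["c1", "c1", "c2"],
    "event_timestamp": pd.to_datetime(["2024-01-10", "2024-01-20", "2024-01-20"]),
})

# Feature history: each row is a value that became known at feature_timestamp.
feature_df = pd.DataFrame({
    "customer_id": ["c1", "c1", "c2", "c2"],
    "feature_timestamp": pd.to_datetime(["2024-01-05", "2024-01-15", "2024-01-01", "2024-01-25"]),
    "lifetime_value": [100.0, 250.0, 40.0, 90.0],
})

# For each event, take the latest feature row at or before the event time for
# the same customer. Rows from later dates would leak future information.
training_df = pd.merge_asof(
    entity_df.sort_values("event_timestamp"),
    feature_df.sort_values("feature_timestamp"),
    left_on="event_timestamp",
    right_on="feature_timestamp",
    by="customer_id",
    direction="backward",
)
print(training_df[["customer_id", "event_timestamp", "lifetime_value"]])
```

Customer `c2` on 2024-01-20 gets 40.0, not the 90.0 recorded five days later. A naive join on `customer_id` alone would leak that future value into training. Feast additionally bounds the lookback with each feature view's `ttl`; the `tolerance` argument of `merge_asof` plays a similar role.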
Common Interview Questions
Q: "How do you handle feature freshness for real-time models?"
freshness_strategies = {
"Streaming Features": "Flink/Spark Streaming for real-time computation",
"Cache Invalidation": "TTL-based expiration with event-driven updates",
"Hybrid Approach": "Batch for historical, streaming for recent",
"Feature Monitoring": "Alert on staleness and freshness violations"
}
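The "Cache Invalidation" answer combines two mechanisms: entries expire after a TTL, and an update event can evict an entry early. In this project the cache lives in Redis, where `SETEX` provides the TTL and a `DEL` on update provides invalidation. The following in-process sketch shows the same logic. `TTLFeatureCache` is a hypothetical class, and the injectable clock exists only to make expiry testable.

```python
import time
from typing import Any, Callable, Optional


class TTLFeatureCache:
    """In-memory feature cache: entries expire after ttl_seconds and can be
    invalidated early when an upstream update event arrives."""

    def __init__(self, ttl_seconds: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.ttl_seconds = ttl_seconds
        self._clock = clock
        self._entries: dict[str, tuple[float, Any]] = {}

    def set(self, key: str, value: Any) -> None:
        # Record when the value was stored so its age can be checked later.
        self._entries[key] = (self._clock(), value)

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self._clock() - stored_at > self.ttl_seconds:
            # Stale: drop it so the caller refetches fresh features.
            del self._entries[key]
            return None
        return value

    def invalidate(self, key: str) -> None:
        """Event-driven path: call when a new value for key is written upstream."""
        self._entries.pop(key, None)
```

The TTL bounds staleness when no update event arrives, and invalidation keeps hot keys fresh without shortening the TTL for every key.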
Q: "How do you ensure feature quality at scale?"
quality_framework = {
"Validation Rules": "Schema validation, range checks, null handling",
"Statistical Tests": "Distribution monitoring, drift detection",
"Automated Testing": "CI/CD pipeline for feature definitions",
"Monitoring Dashboards": "Real-time feature health metrics"
}
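A threshold check over summary statistics can make the "Statistical Tests" entry concrete. The stats use the same shape as `test_data_drift_detection` (`mean`, `std`, `null_ratio`). The following is a sketch, not the project's `FeatureMonitor.detect_data_drift`. The thresholds are arbitrary assumptions, and production systems often use distribution-level tests such as PSI or Kolmogorov-Smirnov instead.

```python
from typing import Any


def detect_drift(current: dict[str, dict[str, float]],
                 baseline: dict[str, dict[str, float]],
                 mean_shift_threshold: float = 1.0,
                 null_ratio_threshold: float = 0.02) -> list[dict[str, Any]]:
    """Flag features whose mean moved by more than mean_shift_threshold
    baseline standard deviations, whose null ratio rose by more than
    null_ratio_threshold (absolute), or that disappeared entirely."""
    alerts = []
    for feature, base in baseline.items():
        cur = current.get(feature)
        if cur is None:
            alerts.append({"feature": feature, "check": "missing", "value": None})
            continue
        # Mean shift in units of baseline std (skipped for constant baselines).
        if base["std"] > 0:
            shift = abs(cur["mean"] - base["mean"]) / base["std"]
            if shift > mean_shift_threshold:
                alerts.append({"feature": feature, "check": "mean_shift", "value": shift})
        null_increase = cur["null_ratio"] - base["null_ratio"]
        if null_increase > null_ratio_threshold:
            alerts.append({"feature": feature, "check": "null_ratio", "value": null_increase})
    return alerts
```

With the numbers from the unit test, the mean moves by 2.0 baseline standard deviations and the null ratio rises by 4 percentage points, so both checks fire.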
Q: "How do you version features for reproducibility?"
versioning_strategy = {
"Feature Snapshots": "Point-in-time snapshots for training",
"Git Integration": "Feature definitions in version control",
"Registry": "Central metadata store for feature versions",
"Lineage Tracking": "Track feature dependencies and transformations"
}
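One lightweight way to tie a trained model to the exact feature definitions it used is to store a content hash of each definition next to the model artifact. The following sketch hashes a plain-dict description of a feature view. The dict layout and `definition_fingerprint` are hypothetical, and Feast's own registry stores definitions in its own format.

```python
import hashlib
import json
from typing import Any


def definition_fingerprint(definition: dict[str, Any]) -> str:
    """Deterministic short hash of a feature definition.

    Keys are sorted, so logically identical definitions hash the same
    regardless of dict ordering; any change to fields, sources or TTL
    yields a new fingerprint that can be recorded with a training run.
    """
    canonical = json.dumps(definition, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]


# Hypothetical definition of the customer_features view.
customer_view = {
    "name": "customer_features",
    "entities": ["customer_id"],
    "ttl_days": 90,
    "fields": {"lifetime_value": "Float32", "total_orders": "Int64"},
}
print(definition_fingerprint(customer_view))
```

Because the definitions already live in Git, the fingerprint complements the commit hash. It changes only when a definition actually changes, not on unrelated commits.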
Deployment Checklist
- Set up Redis cluster for online store
- Configure Snowflake for offline store
- Deploy Feast feature store
- Set up Spark feature engineering pipeline
- Deploy feature serving API
- Configure feature monitoring
- Set up CI/CD for feature definitions
- Test online/offline consistency
- Document feature catalog
- Train ML team on feature store usage
⚠️
Warning: Always validate feature freshness and quality before serving to production models. Stale or low-quality features can significantly degrade model performance.
This project demonstrates ML infrastructure skills and is highly relevant for ML engineering and data science platform interviews.