
Building an ML Feature Store (Feast + Spark + Redis)



ML Feature Store

Feast + Spark + Redis for Real-Time Feature Serving


Project Difficulty: Advanced | Duration: 3-4 weeks | Cloud: AWS/GCP

Build a production ML feature store that manages, serves, and monitors features for machine learning models with sub-millisecond latency.

Project Overview

Problem Statement

ML teams struggle with feature engineering duplication, training-serving skew, and slow feature retrieval. Without a centralized feature store, models take weeks to deploy and suffer from inconsistent feature calculations.

Objectives

  1. Centralize feature definitions and management
  2. Enable sub-millisecond feature serving for online predictions
  3. Ensure training-serving consistency (no skew)
  4. Support batch and real-time feature computation
  5. Provide feature monitoring and drift detection
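
Objective 3 comes down to point-in-time correctness: a training row should only see the feature value that existed at that row's event time, subject to the feature view's TTL. The following is a minimal pure-Python sketch of that lookup, not Feast's implementation (Feast performs the join in the offline store). The function name `point_in_time_value` and the sample history are hypothetical.

```python
from bisect import bisect_right
from datetime import datetime, timedelta
from typing import Optional, Sequence, Tuple


def point_in_time_value(
    history: Sequence[Tuple[datetime, float]],
    event_time: datetime,
    ttl: timedelta,
) -> Optional[float]:
    """Return the latest value at or before event_time, unless it is older than ttl.

    `history` must be sorted by timestamp (assumption of this sketch).
    """
    timestamps = [ts for ts, _ in history]
    idx = bisect_right(timestamps, event_time)  # index of first entry after event_time
    if idx == 0:
        return None  # no feature value existed yet at event_time
    ts, value = history[idx - 1]
    if event_time - ts > ttl:
        return None  # the most recent value is too stale to use
    return value


# Hypothetical feature history for one customer
history = [
    (datetime(2024, 1, 1), 100.0),
    (datetime(2024, 1, 3), 250.0),
]
ttl = timedelta(days=1)

on_jan_1_noon = point_in_time_value(history, datetime(2024, 1, 1, 12), ttl)
on_jan_2_noon = point_in_time_value(history, datetime(2024, 1, 2, 12), ttl)
before_any = point_in_time_value(history, datetime(2023, 12, 31), ttl)
```

Here the lookup at noon on January 1 sees the January 1 value, the lookup at noon on January 2 returns nothing because the only earlier value is older than the one-day TTL, and the January 3 value is never visible to earlier rows. Applying the same rule at training time and at serving time is what prevents skew.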

Tech Stack

Component | Technology | Purpose
Feature Store | Feast | Feature management
Compute | Apache Spark | Batch feature engineering
Online Store | Redis | Low-latency feature serving
Offline Store | Snowflake/BigQuery | Training data
Orchestration | Airflow | Feature pipeline scheduling

Architecture Diagram

DATA SOURCES: Transaction DB (PostgreSQL), Clickstream (Kafka), External APIs (Weather, News)
FEATURE ENGINEERING: Batch Features (Spark), Stream Features (Flink), On-Demand Features (UDFs)
FEAST FEATURE STORE: Feature Views (Definitions), Entity Definitions (Keys), Feature Services (Serving)
STORAGE LAYERS: Offline Store (Data Lake), Online Store (Redis), Registry (Metadata)
ML WORKLOADS: Training Pipelines, Batch Predict Jobs, Real-time Serving, A/B Testing Experiments

Data Source Setup and Schema

Entity and Feature Definitions

# features/definitions.py
from feast import Entity, FeatureView, Field, FileSource, ValueType
from feast.types import Float32, Int64, String, Bool
from feast.on_demand_feature_view import OnDemandFeatureView
from datetime import timedelta
import pandas as pd

# Entity definitions
# Entity.value_type takes the ValueType enum, not a feast.types dtype
customer_entity = Entity(
    name="customer_id",
    join_keys=["customer_id"],
    description="Unique customer identifier",
    value_type=ValueType.STRING
)

product_entity = Entity(
    name="product_id",
    join_keys=["product_id"],
    description="Unique product identifier",
    value_type=ValueType.STRING
)

transaction_entity = Entity(
    name="transaction_id",
    join_keys=["transaction_id"],
    description="Unique transaction identifier",
    value_type=ValueType.STRING
)

# Batch Feature Views - Customer Features
customer_features = FeatureView(
    name="customer_features",
    entities=[customer_entity],
    ttl=timedelta(days=1),
    schema=[
        Field(name="lifetime_value", dtype=Float32),
        Field(name="total_orders", dtype=Int64),
        Field(name="avg_order_value", dtype=Float32),
        Field(name="days_since_last_order", dtype=Int64),
        Field(name="customer_segment", dtype=String),
        Field(name="preferred_category", dtype=String),
        Field(name="total_spend_last_30d", dtype=Float32),
        Field(name="avg_session_duration", dtype=Float32),
        Field(name="preferred_payment_method", dtype=String),
        Field(name="is_active", dtype=Bool),
    ],
    source=FileSource(
        path="s3://feature-store/customer_features.parquet",
        timestamp_field="event_timestamp",
        created_timestamp_column="created_timestamp"
    ),
    tags={"team": "customer_analytics", "layer": "batch"}
)

# Batch Feature Views - Product Features
product_features = FeatureView(
    name="product_features",
    entities=[product_entity],
    ttl=timedelta(days=7),
    schema=[
        Field(name="price", dtype=Float32),
        Field(name="category", dtype=String),
        Field(name="subcategory", dtype=String),
        Field(name="avg_rating", dtype=Float32),
        Field(name="review_count", dtype=Int64),
        Field(name="sales_last_30d", dtype=Int64),
        Field(name="inventory_level", dtype=Int64),
        Field(name="profit_margin", dtype=Float32),
        Field(name="is_bestseller", dtype=Bool),
        Field(name="price_percentile", dtype=Float32),
    ],
    source=FileSource(
        path="s3://feature-store/product_features.parquet",
        timestamp_field="event_timestamp"
    ),
    tags={"team": "product_analytics", "layer": "batch"}
)

# Streaming Feature Views - Real-time Customer Activity
customer_realtime_features = FeatureView(
    name="customer_realtime_features",
    entities=[customer_entity],
    ttl=timedelta(minutes=5),
    schema=[
        Field(name="clicks_last_5m", dtype=Int64),
        Field(name="cart_additions_last_5m", dtype=Int64),
        Field(name="session_duration_current", dtype=Float32),
        Field(name="pages_viewed_current", dtype=Int64),
        Field(name="last_action_type", dtype=String),
        Field(name="is_cart_abandoner", dtype=Bool),
    ],
    source=FileSource(
        path="s3://feature-store/customer_realtime_features.parquet",
        timestamp_field="event_timestamp"
    ),
    tags={"team": "realtime", "layer": "streaming"}
)

# On-Demand Feature Views - Computed Features
# The decorator is the lowercase on_demand_feature_view function; it takes
# `sources` (a list of feature views) rather than an `inputs` dict.
from feast.on_demand_feature_view import on_demand_feature_view

@on_demand_feature_view(
    sources=[customer_features, product_features],
    schema=[
        Field(name="customer_product_affinity", dtype=Float32),
        Field(name="price_sensitivity_score", dtype=Float32),
        Field(name="cross_sell_recommendation", dtype=String),
    ],
    description="Computed features combining customer and product data",
    tags={"team": "ml", "layer": "computed"}
)
def compute_on_demand_features(features_df: pd.DataFrame) -> pd.DataFrame:
    """Compute on-demand features from batch features."""
    result = pd.DataFrame()
    
    # Customer-product affinity (simplified)
    result["customer_product_affinity"] = (
        features_df["lifetime_value"] * features_df["avg_rating"] / 1000
    )
    
    # Price sensitivity score
    result["price_sensitivity_score"] = (
        1 - (features_df["avg_order_value"] / features_df["price"])
    ).clip(0, 1)
    
    # Cross-sell recommendation
    result["cross_sell_recommendation"] = features_df.apply(
        lambda row: "premium" if row["customer_segment"] == "VIP" else "standard",
        axis=1
    )
    
    return result
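
To make the three on-demand formulas above concrete, here is a minimal sketch that applies them to a single customer/product pair in plain Python. The function name `on_demand_row` and the input values are hypothetical, and the zero-price guard is an addition for this sketch that the DataFrame version does not have.

```python
from typing import Dict, Optional, Union


def on_demand_row(
    lifetime_value: float,
    avg_rating: float,
    avg_order_value: float,
    price: float,
    customer_segment: str,
) -> Dict[str, Union[float, str, None]]:
    """Mirror the three on-demand formulas for one customer/product pair."""
    # Customer-product affinity (same simplified formula as above)
    affinity = lifetime_value * avg_rating / 1000

    # Price sensitivity, clipped to [0, 1]; undefined when price is zero
    # (guard added for this sketch only)
    sensitivity: Optional[float] = None
    if price > 0:
        sensitivity = min(max(1 - avg_order_value / price, 0.0), 1.0)

    # Cross-sell recommendation
    recommendation = "premium" if customer_segment == "VIP" else "standard"

    return {
        "customer_product_affinity": affinity,
        "price_sensitivity_score": sensitivity,
        "cross_sell_recommendation": recommendation,
    }


# Hypothetical inputs
vip = on_demand_row(12000.0, 4.5, 50.0, 200.0, "VIP")
big_spender = on_demand_row(800.0, 3.0, 300.0, 100.0, "Regular")
```

For the VIP row the typical order (50) is a quarter of the product price (200), so the sensitivity score is 0.75. For the second row the typical order exceeds the price, the raw score is negative, and the clip pins it to 0.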

Spark Feature Engineering Pipeline

# features/spark_engineering.py
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from datetime import datetime, timedelta
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SparkFeatureEngineer:
    def __init__(self, spark: SparkSession, config: dict):
        self.spark = spark
        self.config = config
        self.window_specs = self._define_window_specifications()
    
    def _define_window_specifications(self) -> dict:
        """Define common window specifications."""
        return {
            # Range bounds are in seconds, so order by epoch seconds.
            # Casting a date straight to long does not give epoch seconds;
            # go through timestamp first.
            'daily': Window.partitionBy('customer_id').orderBy(
                F.col('event_date').cast('timestamp').cast('long')
            ).rangeBetween(-86400, 0),
            
            'weekly': Window.partitionBy('customer_id').orderBy(
                F.col('event_date').cast('timestamp').cast('long')
            ).rangeBetween(-604800, 0),
            
            'monthly': Window.partitionBy('customer_id').orderBy(
                F.col('event_date').cast('timestamp').cast('long')
            ).rangeBetween(-2592000, 0),
            
            'product_daily': Window.partitionBy('product_id').orderBy(
                F.col('event_date').cast('timestamp').cast('long')
            ).rangeBetween(-86400, 0)
        }
    
    def compute_customer_features(self, transactions_df, clickstream_df):
        """Compute comprehensive customer features."""
        logger.info("Computing customer features...")
        
        # Transaction-based features
        transaction_features = transactions_df \
            .withColumn('event_date', F.to_date('order_date')) \
            .groupBy('customer_id', 'event_date') \
            .agg(
                F.sum('amount').alias('daily_spend'),
                F.count('order_id').alias('daily_orders'),
                F.countDistinct('product_id').alias('daily_unique_products')
            )
        
        # Rolling window features. Order by epoch seconds so the second-based
        # range bounds match the unit of the ordering column.
        day_seconds = F.col('event_date').cast('timestamp').cast('long')
        window_7d = Window.partitionBy('customer_id').orderBy(day_seconds) \
            .rangeBetween(-604800, 0)
        window_30d = Window.partitionBy('customer_id').orderBy(day_seconds) \
            .rangeBetween(-2592000, 0)
        rolling_features = transaction_features \
            .withColumn('spend_last_7d', F.sum('daily_spend').over(window_7d)) \
            .withColumn('spend_last_30d', F.sum('daily_spend').over(window_30d)) \
            .withColumn('orders_last_7d', F.sum('daily_orders').over(window_7d)) \
            .withColumn('avg_order_value_7d', F.col('spend_last_7d') / F.col('orders_last_7d'))
        
        # Lifetime features
        lifetime_features = transactions_df \
            .groupBy('customer_id') \
            .agg(
                F.sum('amount').alias('lifetime_value'),
                F.count('order_id').alias('total_orders'),
                F.min('order_date').alias('first_order_date'),
                F.max('order_date').alias('last_order_date'),
                F.countDistinct(F.to_date('order_date')).alias('active_days')
            ) \
            .withColumn('days_since_last_order', 
                       F.datediff(F.current_date(), F.col('last_order_date'))) \
            .withColumn('customer_tenure_days',
                       F.datediff(F.current_date(), F.col('first_order_date'))) \
            .withColumn('order_frequency',
                       F.col('total_orders') / (F.col('customer_tenure_days') / 30))
        
        # Customer segmentation
        segmented_customers = lifetime_features \
            .withColumn('customer_segment',
                F.when(F.col('lifetime_value') > 10000, 'VIP')
                .when(F.col('lifetime_value') > 5000, 'Premium')
                .when(F.col('lifetime_value') > 1000, 'Regular')
                .otherwise('New')
            )
        
        # Clickstream features
        clickstream_features = clickstream_df \
            .withColumn('event_date', F.to_date('event_timestamp')) \
            .groupBy('user_id', 'event_date') \
            .agg(
                F.count('*').alias('daily_events'),
                F.countDistinct('session_id').alias('daily_sessions'),
                F.count(F.when(F.col('event_type') == 'page_view', True)).alias('page_views'),
                F.count(F.when(F.col('event_type') == 'add_to_cart', True)).alias('cart_adds'),
                F.count(F.when(F.col('event_type') == 'purchase', True)).alias('purchases')
            ) \
            .withColumn('conversion_rate', 
                       F.col('purchases') / F.col('page_views')) \
            .withColumn('cart_abandonment_rate',
                       1 - (F.col('purchases') / F.col('cart_adds')))
        
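        # Combine all features: one row per customer per day.
        # Rename user_id and join on both keys so event_date is not duplicated
        # and clickstream days are matched to the same transaction days.
        clickstream_daily = clickstream_features \
            .withColumnRenamed('user_id', 'customer_id') \
            .select('customer_id', 'event_date',
                    'daily_events', 'daily_sessions',
                    'page_views', 'cart_adds')
        final_features = segmented_customers \
            .join(rolling_features.select('customer_id', 'event_date',
                                         'spend_last_7d', 'spend_last_30d',
                                         'orders_last_7d', 'avg_order_value_7d'),
                  'customer_id', 'left') \
            .join(clickstream_daily, ['customer_id', 'event_date'], 'left')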
        
        return final_features
    
    def compute_product_features(self, products_df, sales_df, reviews_df):
        """Compute comprehensive product features."""
        logger.info("Computing product features...")
        
        # Sales-based features
        sales_features = sales_df \
            .withColumn('event_date', F.to_date('order_date')) \
            .groupBy('product_id', 'event_date') \
            .agg(
                F.sum('quantity').alias('daily_sales'),
                F.sum('amount').alias('daily_revenue'),
                F.countDistinct('order_id').alias('daily_orders')
            )
        
        # Rolling sales features. Order by epoch seconds so the second-based
        # range bounds match the unit of the ordering column.
        day_seconds = F.col('event_date').cast('timestamp').cast('long')
        sales_7d = Window.partitionBy('product_id').orderBy(day_seconds) \
            .rangeBetween(-604800, 0)
        sales_30d = Window.partitionBy('product_id').orderBy(day_seconds) \
            .rangeBetween(-2592000, 0)
        rolling_sales = sales_features \
            .withColumn('sales_last_7d', F.sum('daily_sales').over(sales_7d)) \
            .withColumn('sales_last_30d', F.sum('daily_sales').over(sales_30d)) \
            .withColumn('revenue_last_30d', F.sum('daily_revenue').over(sales_30d))
        
        # Review features
        review_features = reviews_df \
            .groupBy('product_id') \
            .agg(
                F.avg('rating').alias('avg_rating'),
                F.count('*').alias('review_count'),
                F.count(F.when(F.col('rating') >= 4, True)).alias('positive_reviews'),
                F.count(F.when(F.col('rating') <= 2, True)).alias('negative_reviews')
            ) \
            .withColumn('positive_review_ratio',
                       F.col('positive_reviews') / F.col('review_count'))
        
        # Product popularity features
        popularity_features = sales_df \
            .withColumn('event_date', F.to_date('order_date')) \
            .groupBy('product_id', 'event_date') \
            .agg(
                F.countDistinct('customer_id').alias('unique_buyers'),
                F.count('*').alias('total_purchases')
            ) \
            .withColumn('popularity_score',
                       F.col('unique_buyers') * F.col('total_purchases'))
        
        # Combine product features. The popularity join uses both keys so
        # event_date is not duplicated and days line up with rolling_sales.
        final_product_features = products_df \
            .join(rolling_sales.select('product_id', 'event_date',
                                      'sales_last_7d', 'sales_last_30d',
                                      'revenue_last_30d'),
                  'product_id', 'left') \
            .join(review_features, 'product_id', 'left') \
            .join(popularity_features.select('product_id', 'event_date',
                                            'popularity_score'),
                  ['product_id', 'event_date'], 'left')
        
        return final_product_features
    
    def compute_realtime_features(self, clickstream_batch_df, 
                                 customer_history_df):
        """Compute real-time features from streaming data."""
        logger.info("Computing real-time features...")
        
        # Current session features
        session_features = clickstream_batch_df \
            .groupBy('user_id', 'session_id') \
            .agg(
                F.count('*').alias('session_events'),
                F.countDistinct('page_url').alias('unique_pages'),
                F.min('event_timestamp').alias('session_start'),
                F.max('event_timestamp').alias('session_end'),
                F.count(F.when(F.col('event_type') == 'add_to_cart', True))
                    .alias('cart_additions')
            ) \
            .withColumn('session_duration_seconds',
                       F.unix_timestamp('session_end') - F.unix_timestamp('session_start'))
        
        # Real-time aggregates (last 5 minutes)
        realtime_agg = clickstream_batch_df \
            .withWatermark('event_timestamp', '2 minutes') \
            .groupBy(
                F.window('event_timestamp', '5 minutes'),
                'user_id'
            ) \
            .agg(
                F.count('*').alias('events_last_5m'),
                F.count(F.when(F.col('event_type') == 'page_view', True))
                    .alias('page_views_last_5m'),
                F.count(F.when(F.col('event_type') == 'add_to_cart', True))
                    .alias('cart_adds_last_5m')
            )
        
        # Merge with historical features
        realtime_features = realtime_agg \
            .join(customer_history_df, 'user_id', 'left') \
            .withColumn('engagement_score',
                F.col('events_last_5m') * 0.3 +
                F.col('page_views_last_5m') * 0.4 +
                F.col('cart_adds_last_5m') * 0.3
            )
        
        return realtime_features
    
    def save_features_to_store(self, features_df, feature_name: str,
                              output_path: str):
        """Save computed features to the feature store."""
        logger.info(f"Saving {feature_name} features to {output_path}")
        
        # Add metadata columns
        features_with_metadata = features_df \
            .withColumn('event_timestamp', F.current_timestamp()) \
            .withColumn('created_timestamp', F.current_timestamp()) \
            .withColumn('feature_version', F.lit('v1.0'))
        
        # Write to data lake in Parquet format
        features_with_metadata.write \
            .mode('overwrite') \
            .partitionBy('event_date') \
            .parquet(output_path)
        
        # Registration with Feast is a separate step (e.g. `feast apply`); it is not done here
        logger.info(f"Saved {feature_name} features; register them with Feast separately")
    
    def run_full_feature_pipeline(self):
        """Run the complete feature engineering pipeline."""
        logger.info("Starting full feature engineering pipeline")
        
        # Read source data
        transactions_df = self.spark.read.format('jdbc') \
            .option('url', self.config['db_url']) \
            .option('dbtable', 'public.orders') \
            .load()
        
        clickstream_df = self.spark.read.parquet(
            self.config['clickstream_path']
        )
        
        products_df = self.spark.read.format('jdbc') \
            .option('url', self.config['db_url']) \
            .option('dbtable', 'public.products') \
            .load()
        
        sales_df = self.spark.read.format('jdbc') \
            .option('url', self.config['db_url']) \
            .option('dbtable', 'public.order_items') \
            .load()
        
        reviews_df = self.spark.read.parquet(
            self.config['reviews_path']
        )
        
        # Compute features
        customer_features = self.compute_customer_features(
            transactions_df, clickstream_df
        )
        product_features = self.compute_product_features(
            products_df, sales_df, reviews_df
        )
        
        # Save features
        self.save_features_to_store(
            customer_features,
            'customer_features',
            f"{self.config['feature_store_path']}/customer_features"
        )
        
        self.save_features_to_store(
            product_features,
            'product_features',
            f"{self.config['feature_store_path']}/product_features"
        )
        
        logger.info("Feature engineering pipeline completed")
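
The rolling features in this pipeline depend on range-based window semantics: for each row, the aggregate covers every row in the same partition whose ordering value falls between the current value minus an offset and the current value. The sketch below reproduces that rule in plain Python for a single customer, using whole days as the unit. The function name `rolling_sum` and the sample data are hypothetical, and this version counts exactly `days` calendar days including the current one.

```python
from datetime import date, timedelta
from typing import Dict, List, Tuple


def rolling_sum(daily: List[Tuple[date, float]], days: int) -> Dict[date, float]:
    """For each date d, sum the values whose date lies in [d - (days - 1), d]."""
    result: Dict[date, float] = {}
    for current_day, _ in daily:
        window_start = current_day - timedelta(days=days - 1)
        result[current_day] = sum(
            value for other_day, value in daily
            if window_start <= other_day <= current_day
        )
    return result


# Hypothetical daily spend for one customer (days with no orders are absent)
daily_spend = [
    (date(2024, 1, 1), 10.0),
    (date(2024, 1, 3), 20.0),
    (date(2024, 1, 8), 5.0),
    (date(2024, 1, 9), 1.0),
]

spend_7d = rolling_sum(daily_spend, days=7)
```

Because the frame is defined by value range rather than row count, missing days do not stretch the window: on January 8 the window starts on January 2, so the January 1 spend drops out even though it is only two rows back.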

Step-by-Step Implementation Guide

Step 1: Feast Configuration

# feast/feature_store.yaml
project: ecommerce_feature_store
registry: s3://feature-store/registry/registry.db
provider: aws
online_store:
  type: redis
  redis_type: redis_cluster
  connection_string: "redis-cluster.internal:6379,redis-cluster.internal:6380,redis-cluster.internal:6381,ssl=true"
  
offline_store:
  type: snowflake
  account: ${SNOWFLAKE_ACCOUNT}
  user: ${SNOWFLAKE_USER}
  password: ${SNOWFLAKE_PASSWORD}
  role: FEATURE_STORE_ROLE
  warehouse: FEATURE_STORE_WH
  database: ECOMMERCE
  schema: FEATURES
  
entity_key_serialization_version: 2

feature_server:
  host: 0.0.0.0
  port: 6566
  enable_prometheus: true
  
logging:
  level: INFO
  path: /var/log/feast
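
The `${...}` placeholders in the Snowflake section are expected to be filled from environment variables when the file is loaded. A small pre-flight check can catch an unset variable before deployment. This is a sketch under that assumption; the helper name `missing_env_vars` is hypothetical and is not part of Feast.

```python
import re
from typing import List, Mapping

# Matches ${NAME} placeholders such as ${SNOWFLAKE_ACCOUNT}
PLACEHOLDER = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def missing_env_vars(yaml_text: str, env: Mapping[str, str]) -> List[str]:
    """Return the placeholder names that are unset or empty in `env`, sorted."""
    names = sorted(set(PLACEHOLDER.findall(yaml_text)))
    return [name for name in names if not env.get(name)]


config_text = """
offline_store:
  type: snowflake
  account: ${SNOWFLAKE_ACCOUNT}
  user: ${SNOWFLAKE_USER}
  password: ${SNOWFLAKE_PASSWORD}
"""

# Hypothetical environment with the password left unset
fake_env = {"SNOWFLAKE_ACCOUNT": "acme", "SNOWFLAKE_USER": "svc_feast"}
missing = missing_env_vars(config_text, fake_env)
```

In a real deployment, `os.environ` would be passed in place of `fake_env`, and a non-empty result would fail the deploy step.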

Step 2: Feature Serving API

# serving/feature_server.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import feast
import pandas as pd
import redis
import json
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Feature Store API",
    description="Real-time feature serving for ML models",
    version="1.0.0"
)

class FeatureRequest(BaseModel):
    entity_rows: List[Dict[str, Any]]
    feature_refs: List[str]
    full_feature_names: bool = False

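class FeatureResponse(BaseModel):
    # metadata holds a list and an int as well as strings;
    # results is the column-oriented dict returned by to_dict()
    metadata: Dict[str, Any]
    results: Dict[str, List[Any]]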

class FeatureStoreService:
    def __init__(self):
        # repo_path is the directory that contains feature_store.yaml
        self.store = feast.FeatureStore(repo_path="/opt/feast")
        self.redis_client = redis.Redis(
            host='redis-cluster.internal',
            port=6379,
            decode_responses=True,
            ssl=True
        )
        self.cache_ttl = 300  # 5 minutes
    
    def get_online_features(self, entity_rows: List[Dict], 
                           feature_refs: List[str]) -> Dict:
        """Get features from online store with caching."""
        # Check cache first
        cache_key = self._generate_cache_key(entity_rows, feature_refs)
        cached = self.redis_client.get(cache_key)
        
        if cached:
            logger.info("Returning cached features")
            return json.loads(cached)
        
        # Get from Feast online store
        response = self.store.get_online_features(
            features=feature_refs,
            entity_rows=entity_rows,
            full_feature_names=True
        )
        
        result = {
            'metadata': {
                'feature_refs': feature_refs,
                'entity_count': len(entity_rows),
                'timestamp': datetime.utcnow().isoformat()
            },
            'results': response.to_dict()
        }
        
        # Cache the result
        self.redis_client.setex(
            cache_key,
            self.cache_ttl,
            json.dumps(result)
        )
        
        return result
    
    def get_historical_features(self, entity_df: pd.DataFrame,
                               feature_refs: List[str]) -> pd.DataFrame:
        """Get historical features for training."""
        return self.store.get_historical_features(
            entity_df=entity_df,
            features=feature_refs
        ).to_df()
    
    def _generate_cache_key(self, entity_rows: List[Dict],
                           feature_refs: List[str]) -> str:
        """Generate cache key for feature request."""
        import hashlib
        
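        key_data = {
            # Keep the request's row order: results are positional, so a
            # reordered request must not reuse another request's cached response
            'entities': [str(sorted(row.items())) for row in entity_rows],
            'features': sorted(feature_refs)
        }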
        
        key_hash = hashlib.md5(
            json.dumps(key_data, sort_keys=True).encode()
        ).hexdigest()
        
        return f"features:{key_hash}"

# Initialize service
feature_service = FeatureStoreService()

@app.post("/get-features", response_model=FeatureResponse)
async def get_features(request: FeatureRequest):
    """Get real-time features for online serving."""
    try:
        result = feature_service.get_online_features(
            entity_rows=request.entity_rows,
            feature_refs=request.feature_refs
        )
        return FeatureResponse(**result)
    except Exception as e:
        logger.error(f"Feature retrieval failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/feature-views")
async def list_feature_views():
    """List all available feature views."""
    feature_views = feature_service.store.list_feature_views()
    return {
        'feature_views': [
            {
                'name': fv.name,
                'entities': fv.entities,
                'ttl': str(fv.ttl),
                'features': [f.name for f in fv.features]
            }
            for fv in feature_views
        ]
    }

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        'status': 'healthy',
        'timestamp': datetime.utcnow().isoformat(),
        'redis_connected': feature_service.redis_client.ping()
    }

@app.get("/metrics")
async def get_metrics():
    """Get feature store metrics."""
    # Prometheus metrics would be exposed here
    return {
        'requests_total': 0,
        'cache_hit_rate': 0.0,
        'avg_latency_ms': 0.0
    }
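
The caching layer in this service rests on two ideas: a deterministic key derived from the request, and a time-to-live after which the entry is discarded. The sketch below shows both with the standard library only, so it can be run without Redis. `TTLCache` is a hypothetical in-memory stand-in for the `SETEX`/`GET` calls, and SHA-256 replaces MD5 here purely as a choice for this sketch.

```python
import hashlib
import json
import time
from typing import Any, Callable, Dict, List, Optional, Tuple


def cache_key(entity_rows: List[Dict[str, Any]], feature_refs: List[str]) -> str:
    """Deterministic key: feature order is ignored, entity row order is kept."""
    payload = {
        "entities": [str(sorted(row.items())) for row in entity_rows],
        "features": sorted(feature_refs),
    }
    digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
    return f"features:{digest}"


class TTLCache:
    """In-memory stand-in for Redis SETEX/GET, for local experiments only."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic):
        self.ttl_seconds = ttl_seconds
        self.clock = clock
        self._items: Dict[str, Tuple[float, Any]] = {}

    def set(self, key: str, value: Any) -> None:
        self._items[key] = (self.clock() + self.ttl_seconds, value)

    def get(self, key: str) -> Optional[Any]:
        item = self._items.get(key)
        if item is None:
            return None
        expires_at, value = item
        if self.clock() >= expires_at:
            del self._items[key]  # expired: drop it and report a miss
            return None
        return value


rows = [{"customer_id": "c1"}, {"customer_id": "c2"}]
refs = ["customer_features:lifetime_value", "customer_features:total_orders"]

# Reordering features gives the same key; reordering entity rows does not
same_key = cache_key(rows, refs) == cache_key(rows, list(reversed(refs)))
order_matters = cache_key(rows, refs) != cache_key(list(reversed(rows)), refs)
```

Feature order can be normalized because the response is keyed by feature name, while entity rows are returned positionally and therefore have to stay in request order. Note that a cached response can be up to `cache_ttl` seconds older than the online store, which matters for feature views whose own TTL is only a few minutes.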

Step 3: Feature Monitoring
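
Drift detection in this step refers to the Population Stability Index (PSI). As background, PSI compares two binned distributions: with baseline bin fractions b_i and current bin fractions c_i, PSI is the sum over bins of (c_i - b_i) * ln(c_i / b_i). Below is a minimal standard-library sketch that works on raw samples. The function name `psi`, the equal-width binning on the baseline range, and the `eps` floor for empty bins are assumptions of this sketch rather than anything prescribed by the monitor class.

```python
import math
from typing import List, Sequence


def psi(baseline: Sequence[float], current: Sequence[float],
        bins: int = 10, eps: float = 1e-6) -> float:
    """Population Stability Index over equal-width bins of the baseline range."""
    lo, hi = min(baseline), max(baseline)
    if hi == lo:
        raise ValueError("baseline has no spread, so it cannot be binned")
    width = (hi - lo) / bins

    def fractions(values: Sequence[float]) -> List[float]:
        counts = [0] * bins
        for v in values:
            idx = int((v - lo) / width)
            idx = min(max(idx, 0), bins - 1)  # out-of-range values go to the edge bins
            counts[idx] += 1
        # Floor empty bins at eps so the logarithm stays defined
        return [max(c / len(values), eps) for c in counts]

    b = fractions(baseline)
    c = fractions(current)
    return sum((ci - bi) * math.log(ci / bi) for bi, ci in zip(b, c))


# Hypothetical samples: an unchanged feature and one shifted upward by 50
baseline_sample = [float(v) for v in range(100)]
unchanged_sample = list(baseline_sample)
shifted_sample = [v + 50.0 for v in baseline_sample]

psi_unchanged = psi(baseline_sample, unchanged_sample)
psi_shifted = psi(baseline_sample, shifted_sample)
```

A common rule of thumb reads PSI below 0.1 as stable, 0.1 to 0.25 as a moderate shift, and above 0.25 as a significant shift. Identical samples give exactly 0, and the shifted sample lands well above 0.25.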

# monitoring/feature_monitor.py
import pandas as pd
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FeatureMonitor:
    def __init__(self, feature_store, alert_callback=None):
        self.feature_store = feature_store
        self.alert_callback = alert_callback
        self.baseline_stats = {}
    
    def compute_feature_statistics(self, features_df: pd.DataFrame) -> Dict:
        """Compute comprehensive statistics for features."""
        stats = {}
        
        for column in features_df.select_dtypes(include=[np.number]).columns:
            col_stats = {
                'mean': float(features_df[column].mean()),
                'std': float(features_df[column].std()),
                'min': float(features_df[column].min()),
                'max': float(features_df[column].max()),
                'median': float(features_df[column].median()),
                'null_count': int(features_df[column].isnull().sum()),
                'null_ratio': float(features_df[column].isnull().mean()),
                'skewness': float(features_df[column].skew()),
                'kurtosis': float(features_df[column].kurtosis()),
                'percentile_25': float(features_df[column].quantile(0.25)),
                'percentile_75': float(features_df[column].quantile(0.75)),
                'iqr': float(features_df[column].quantile(0.75) - features_df[column].quantile(0.25))
            }
            
            # Detect outliers using IQR method
            q1 = features_df[column].quantile(0.25)
            q3 = features_df[column].quantile(0.75)
            iqr = q3 - q1
            outlier_count = int(((features_df[column] < q1 - 1.5 * iqr) | 
                               (features_df[column] > q3 + 1.5 * iqr)).sum())
            col_stats['outlier_count'] = outlier_count
            col_stats['outlier_ratio'] = outlier_count / len(features_df)
            
            stats[column] = col_stats
        
        return stats
    
    def detect_data_drift(self, current_stats: Dict, 
                         baseline_stats: Dict,
                         threshold: float = 0.1) -> List[Dict]:
        """Detect data drift using statistical tests."""
        drift_alerts = []
        
        for feature_name, current in current_stats.items():
            if feature_name not in baseline_stats:
                continue
            
            baseline = baseline_stats[feature_name]
            
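            # Kolmogorov-Smirnov test for distribution drift.
            # NOTE: ks_2samp compares samples, not summary statistics. With a
            # single value per side it can never reach p < 0.05, so only the
            # PSI check below can fire; pass raw feature values for a real KS test.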
            ks_stat, p_value = stats.ks_2samp(
                [current['mean']],
                [baseline['mean']]
            )
            
            # Population Stability Index (PSI)
            psi = self._calculate_psi(
                baseline['mean'], 
                current['mean'],
                baseline['std']
            )
            
            if p_value < 0.05 or psi > threshold:
                drift_alerts.append({
                    'feature': feature_name,
                    'drift_type': 'distribution',
                    'p_value': p_value,
                    'psi': psi,
                    'baseline_mean': baseline['mean'],
                    'current_mean': current['mean'],
                    'severity': 'high' if psi > threshold * 2 else 'medium'
                })
            
            # Null ratio drift
            null_ratio_change = abs(current['null_ratio'] - baseline['null_ratio'])
            if null_ratio_change > 0.05:
                drift_alerts.append({
                    'feature': feature_name,
                    'drift_type': 'null_ratio',
                    'baseline_null_ratio': baseline['null_ratio'],
                    'current_null_ratio': current['null_ratio'],
                    'severity': 'medium'
                })
        
        return drift_alerts
    
    def _calculate_psi(self, baseline_mean: float, current_mean: float,
                      baseline_std: float) -> float:
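        """Approximate drift as the mean shift measured in baseline standard deviations."""
        # Not a true PSI: PSI compares binned distributions, which needs raw
        # samples. This is a z-score style proxy computed from summary stats.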
        if baseline_std == 0:
            return 0
        
        psi = abs(current_mean - baseline_mean) / baseline_std
        return psi
    
    def validate_feature_freshness(self, feature_view: str,
                                  max_age_seconds: int = 300) -> Dict:
        """Validate feature freshness and timeliness."""
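        # Get latest feature timestamp.
        # NOTE: the 'metadata' lookup below assumes self.feature_store returns
        # the dict shape built by FeatureStoreService in Step 2. That timestamp
        # is the request time, so this measures response age, not feature age.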
        latest_features = self.feature_store.get_online_features(
            features=[f"{feature_view}:event_timestamp"],
            entity_rows=[{'customer_id': 'latest_check'}]
        )
        
        if latest_features:
            latest_timestamp = datetime.fromisoformat(
                latest_features['metadata']['timestamp']
            )
            age_seconds = (datetime.utcnow() - latest_timestamp).total_seconds()
            
            is_fresh = age_seconds <= max_age_seconds
            
            return {
                'feature_view': feature_view,
                'latest_timestamp': latest_timestamp.isoformat(),
                'age_seconds': age_seconds,
                'is_fresh': is_fresh,
                'max_age_seconds': max_age_seconds
            }
        
        return {
            'feature_view': feature_view,
            'is_fresh': False,
            'error': 'No features found'
        }
    
    def monitor_feature_quality(self, features_df: pd.DataFrame) -> Dict:
        """Comprehensive feature quality monitoring."""
        quality_report = {
            'timestamp': datetime.utcnow().isoformat(),
            'total_features': len(features_df.columns),
            'total_rows': len(features_df),
            'issues': []
        }
        
        for column in features_df.columns:
            feature_issues = []
            
            # Check for high null ratio
            null_ratio = features_df[column].isnull().mean()
            if null_ratio > 0.1:
                feature_issues.append({
                    'type': 'high_null_ratio',
                    'value': float(null_ratio),
                    'threshold': 0.1
                })
            
            # Check for constant features
            if features_df[column].nunique() == 1:
                feature_issues.append({
                    'type': 'constant_feature',
                    'unique_values': 1
                })
            
            # Check for high cardinality in categorical features
            if features_df[column].dtype == 'object':
                cardinality = features_df[column].nunique()
                if cardinality > 100:
                    feature_issues.append({
                        'type': 'high_cardinality',
                        'cardinality': cardinality,
                        'threshold': 100
                    })
            
            # Check for numeric outliers
            if pd.api.types.is_numeric_dtype(features_df[column]):
                q1 = features_df[column].quantile(0.25)
                q3 = features_df[column].quantile(0.75)
                iqr = q3 - q1
                outlier_ratio = ((features_df[column] < q1 - 3 * iqr) | 
                                (features_df[column] > q3 + 3 * iqr)).mean()
                
                if outlier_ratio > 0.01:
                    feature_issues.append({
                        'type': 'extreme_outliers',
                        'ratio': float(outlier_ratio),
                        'threshold': 0.01
                    })
            
            if feature_issues:
                quality_report['issues'].append({
                    'feature': column,
                    'issues': feature_issues
                })
        
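        quality_report['features_with_issues'] = len(quality_report['issues'])
        # Guard against a DataFrame with no columns to avoid division by zero
        total_features = quality_report['total_features']
        quality_report['quality_score'] = (
            (1 - quality_report['features_with_issues'] / total_features) * 100
            if total_features else 100.0
        )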
        
        return quality_report
    
    def generate_monitoring_report(self) -> Dict:
        """Generate comprehensive monitoring report."""
        report = {
            'timestamp': datetime.utcnow().isoformat(),
            'feature_views': {}
        }
        
        # Monitor each feature view
        feature_views = self.feature_store.list_feature_views()
        
        for fv in feature_views:
            logger.info(f"Monitoring feature view: {fv.name}")
            
            # Get sample data
            sample_entities = [{'customer_id': f'customer_{i}'} 
                              for i in range(1000)]
            
            features = self.feature_store.get_online_features(
                features=[f"{fv.name}:{f.name}" for f in fv.features],
                entity_rows=sample_entities
            )
            
            if features:
                features_df = pd.DataFrame(features['results'])
                
                # Compute statistics
                stats = self.compute_feature_statistics(features_df)
                
                # Check freshness
                freshness = self.validate_feature_freshness(fv.name)
                
                # Check quality
                quality = self.monitor_feature_quality(features_df)
                
                report['feature_views'][fv.name] = {
                    'statistics': stats,
                    'freshness': freshness,
                    'quality': quality
                }
        
        # Check for drift if baseline exists
        if self.baseline_stats:
            current_stats = {}
            for fv_name, fv_report in report['feature_views'].items():
                current_stats[fv_name] = fv_report['statistics']
            
            drift_alerts = self.detect_data_drift(
                current_stats, 
                self.baseline_stats
            )
            report['drift_alerts'] = drift_alerts
        
        # Store baseline for future comparison
        self.baseline_stats = {
            fv_name: fv_report['statistics']
            for fv_name, fv_report in report['feature_views'].items()
        }
        
        return report

Infrastructure Setup (Terraform)

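# infrastructure/ml_feature_store.tf
# Note: module.vpc, var.environment, var.aws_region, var.aws_account_id,
# aws_security_group.feature_server, aws_security_group.alb,
# aws_iam_role.emr_service, and aws_iam_instance_profile.emr are referenced
# below but must be defined elsewhere in the configuration.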
terraform {
  required_version = ">= 1.5.0"
  
  required_providers {
    aws = {
      source  = "hashicorp/aws"
      version = "~> 5.0"
    }
  }
}

# Redis Cluster for Online Store
resource "aws_elasticache_replication_group" "feature_redis" {
  replication_group_id       = "ml-feature-store-redis"
  description                = "Redis cluster for ML feature store online serving"
  
  node_type            = "cache.r6g.xlarge"
  num_cache_clusters   = 6
  
  engine               = "redis"
  engine_version       = "7.0"
  port                 = 6379
  
  subnet_group_name    = aws_elasticache_subnet_group.redis.name
  security_group_ids   = [aws_security_group.redis.id]
  
  at_rest_encryption_enabled = true
  transit_encryption_enabled = true
  
  automatic_failover_enabled = true
  multi_az_enabled          = true
  
  snapshot_retention_limit = 7
  snapshot_window         = "03:00-04:00"
  
  tags = {
    Project = "ml-feature-store"
    Purpose = "online-feature-serving"
  }
}

# ElastiCache Subnet Group
resource "aws_elasticache_subnet_group" "redis" {
  name       = "ml-feature-redis-subnet"
  subnet_ids = module.vpc.private_subnets
}

# Security Group for Redis
resource "aws_security_group" "redis" {
  name_prefix = "ml-feature-redis-"
  vpc_id      = module.vpc.vpc_id
  
  ingress {
    from_port       = 6379
    to_port         = 6379
    protocol        = "tcp"
    cidr_blocks     = [module.vpc.vpc_cidr_block]
    description     = "Redis"
  }
  
  egress {
    from_port   = 0
    to_port     = 0
    protocol    = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }
}

# S3 Bucket for Feature Store
resource "aws_s3_bucket" "feature_store" {
  bucket = "ml-feature-store-${var.environment}"
  
  tags = {
    Project = "ml-feature-store"
    Purpose = "feature-storage"
  }
}

resource "aws_s3_bucket_versioning" "feature_store" {
  bucket = aws_s3_bucket.feature_store.id
  
  versioning_configuration {
    status = "Enabled"
  }
}

resource "aws_s3_bucket_server_side_encryption_configuration" "feature_store" {
  bucket = aws_s3_bucket.feature_store.id
  
  rule {
    apply_server_side_encryption_by_default {
      sse_algorithm = "aws:kms"
    }
  }
}

# EMR Cluster for Feature Engineering
resource "aws_emr_cluster" "feature_engineering" {
  name          = "ml-feature-engineering"
  release_label = "emr-6.15.0"
  applications  = ["Spark", "Hive"]
  
  master_instance_group {
    instance_type  = "m5.xlarge"
    instance_count = 1
  }
  
  core_instance_group {
    instance_type  = "r5.2xlarge"
    instance_count = 3
    
    ebs_config {
      type                 = "gp3"
      size                 = 100
      volumes_per_instance = 1
    }
  }
  
  ec2_attributes {
    instance_profile = aws_iam_instance_profile.emr.arn
    subnet_id        = module.vpc.private_subnets[0]
  }
  
  service_role = aws_iam_role.emr_service.arn
  
  tags = {
    Project = "ml-feature-store"
    Purpose = "feature-engineering"
  }
}

# ECS Cluster for Feature Server
resource "aws_ecs_cluster" "feature_server" {
  name = "ml-feature-server"
  
  setting {
    name  = "containerInsights"
    value = "enabled"
  }
  
  tags = {
    Project = "ml-feature-store"
    Purpose = "feature-serving"
  }
}

# ECS Task Definition for Feature Server
resource "aws_ecs_task_definition" "feature_server" {
  family                   = "feature-server"
  network_mode             = "awsvpc"
  requires_compatibilities = ["FARGATE"]
  cpu                      = "2048"
  memory                   = "4096"
  
  container_definitions = jsonencode([
    {
      name  = "feature-server"
      image = "${var.aws_account_id}.dkr.ecr.${var.aws_region}.amazonaws.com/feature-server:latest"
      
      portMappings = [
        {
          containerPort = 6566
          hostPort      = 6566
          protocol      = "tcp"
        }
      ]
      
      environment = [
        {
          name  = "FEAST_CONFIG_PATH"
          value = "/opt/feast/feature_store.yaml"
        },
        {
          name  = "REDIS_HOST"
          value = aws_elasticache_replication_group.feature_redis.primary_endpoint_address
        }
      ]
      
      logConfiguration = {
        logDriver = "awslogs"
        options = {
          "awslogs-group"         = aws_cloudwatch_log_group.feature_server.name
          "awslogs-region"        = var.aws_region
          "awslogs-stream-prefix" = "feature-server"
        }
      }
    }
  ])
  
  execution_role_arn = aws_iam_role.ecs_execution.arn
  task_role_arn      = aws_iam_role.ecs_task.arn
}

# ECS Service for Feature Server
resource "aws_ecs_service" "feature_server" {
  name            = "feature-server"
  cluster         = aws_ecs_cluster.feature_server.id
  task_definition = aws_ecs_task_definition.feature_server.arn
  desired_count   = 3
  launch_type     = "FARGATE" # task definition is Fargate-only; the default launch type is EC2
  
  network_configuration {
    subnets          = module.vpc.private_subnets
    security_groups  = [aws_security_group.feature_server.id]
    assign_public_ip = false
  }
  
  load_balancer {
    target_group_arn = aws_lb_target_group.feature_server.arn
    container_name   = "feature-server"
    container_port   = 6566
  }
  
  depends_on = [aws_lb_listener.feature_server]
}

# Application Load Balancer
resource "aws_lb" "feature_server" {
  name               = "feature-server-alb"
  internal           = true
  load_balancer_type = "application"
  security_groups    = [aws_security_group.alb.id]
  subnets            = module.vpc.private_subnets
  
  tags = {
    Project = "ml-feature-store"
  }
}

resource "aws_lb_target_group" "feature_server" {
  name        = "feature-server-tg"
  port        = 6566
  protocol    = "HTTP"
  vpc_id      = module.vpc.vpc_id
  target_type = "ip"
  
  health_check {
    path                = "/health"
    port                = "traffic-port"
    healthy_threshold   = 3
    unhealthy_threshold = 3
    timeout             = 5
    interval            = 30
    matcher             = "200"
  }
}

resource "aws_lb_listener" "feature_server" {
  load_balancer_arn = aws_lb.feature_server.arn
  port              = 80
  protocol          = "HTTP"
  
  default_action {
    type             = "forward"
    target_group_arn = aws_lb_target_group.feature_server.arn
  }
}

# CloudWatch Log Group
resource "aws_cloudwatch_log_group" "feature_server" {
  name              = "/ecs/feature-server"
  retention_in_days = 30
}

# IAM Roles
resource "aws_iam_role" "ecs_execution" {
  name = "ecs-feature-server-execution"
  
  assume_role_policy = jsonencode({
    Version = "2012-10-17"
    Statement = [
      {
        Action = "sts:AssumeRole"
        Effect = "Allow"
        Principal = {
          Service = "ecs-tasks.amazonaws.com"
        }
      }
    ]
  })
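}

# The execution role needs this managed policy to pull the image from ECR
# and write to CloudWatch Logs
resource "aws_iam_role_policy_attachment" "ecs_execution" {
  role       = aws_iam_role.ecs_execution.name
  policy_arn = "arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy"
}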

resource "aws_iam_role" "ecs_task" {
  name = "ecs-feature-server-task"
  
  assume_role_policy = jsonencode({
    Version = "2012-10-17"
    Statement = [
      {
        Action = "sts:AssumeRole"
        Effect = "Allow"
        Principal = {
          Service = "ecs-tasks.amazonaws.com"
        }
      }
    ]
  })
}

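# Least-privilege task policy: read-only ElastiCache metadata and object
# access limited to the feature store bucket
resource "aws_iam_role_policy" "ecs_task" {
  name = "ecs-feature-server-task-policy"
  role = aws_iam_role.ecs_task.id
  
  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [
      {
        Effect   = "Allow"
        Action   = ["elasticache:Describe*"]
        Resource = "*"
      },
      {
        Effect = "Allow"
        Action = [
          "s3:GetObject",
          "s3:PutObject"
        ]
        Resource = "${aws_s3_bucket.feature_store.arn}/*"
      }
    ]
  })
}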

# Outputs
output "redis_endpoint" {
  description = "Redis endpoint for online feature serving"
  value       = aws_elasticache_replication_group.feature_redis.primary_endpoint_address
}

output "feature_store_bucket" {
  description = "S3 bucket for feature store"
  value       = aws_s3_bucket.feature_store.bucket
}

output "feature_server_dns" {
  description = "Feature server DNS name"
  value       = aws_lb.feature_server.dns_name
}
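
The task definition above passes REDIS_HOST to the container, and the replication group enables in-transit encryption, so the feature server has to connect over TLS. The following is a minimal sketch of how that host could be turned into the online store section of feature_store.yaml. It assumes Feast's Redis online store accepts a connection string of the form host:port,ssl=true; the helper name build_online_store_config is hypothetical and not part of the project code.

```python
import os


def build_online_store_config(host: str, port: int = 6379, use_tls: bool = True) -> dict:
    """Build the online_store section of feature_store.yaml as a dict.

    Hypothetical helper: the keys are assumed to match Feast's Redis
    online store settings (type, connection_string).
    """
    connection_string = f"{host}:{port}"
    if use_tls:
        # transit_encryption_enabled = true on the replication group
        # means clients must connect with TLS.
        connection_string += ",ssl=true"
    return {"type": "redis", "connection_string": connection_string}


if __name__ == "__main__":
    # REDIS_HOST is the variable set in the ECS task definition.
    host = os.environ.get("REDIS_HOST", "localhost")
    print(build_online_store_config(host))
```

Keeping the host in an environment variable lets the same container image run against different Redis endpoints per environment.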

Testing and Validation

# tests/test_feature_store.py
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock

class TestFeatureStore:
    @pytest.fixture
    def sample_features(self):
        return pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(1000)],
            'lifetime_value': np.random.uniform(0, 10000, 1000),
            'total_orders': np.random.randint(1, 50, 1000),
            'avg_order_value': np.random.uniform(10, 500, 1000),
            'days_since_last_order': np.random.randint(0, 365, 1000),
            'customer_segment': np.random.choice(['VIP', 'Premium', 'Regular', 'New'], 1000)
        })
    
    @pytest.fixture
    def feature_monitor(self):
        from monitoring.feature_monitor import FeatureMonitor
        mock_store = Mock()
        return FeatureMonitor(mock_store)
    
    def test_feature_statistics(self, feature_monitor, sample_features):
        """Test feature statistics computation."""
        stats = feature_monitor.compute_feature_statistics(sample_features)
        
        assert 'lifetime_value' in stats
        assert 'total_orders' in stats
        
        # Check statistics are computed
        assert 'mean' in stats['lifetime_value']
        assert 'std' in stats['lifetime_value']
        assert 'null_count' in stats['lifetime_value']
    
    def test_data_drift_detection(self, feature_monitor):
        """Test data drift detection."""
        baseline_stats = {
            'feature_1': {
                'mean': 100.0,
                'std': 10.0,
                'null_ratio': 0.01
            }
        }
        
        current_stats = {
            'feature_1': {
                'mean': 120.0,
                'std': 15.0,
                'null_ratio': 0.05
            }
        }
        
        drift_alerts = feature_monitor.detect_data_drift(
            current_stats, baseline_stats
        )
        
        assert len(drift_alerts) > 0
        assert drift_alerts[0]['feature'] == 'feature_1'
    
    def test_feature_freshness(self, feature_monitor):
        """Test feature freshness validation."""
        with patch.object(feature_monitor.feature_store, 'get_online_features') as mock:
            mock.return_value = {
                'metadata': {
                    'timestamp': datetime.utcnow().isoformat()
                }
            }
            
            freshness = feature_monitor.validate_feature_freshness(
                'test_feature_view',
                max_age_seconds=300
            )
            
            assert freshness['is_fresh']
    
    def test_feature_quality(self, feature_monitor):
        """Test feature quality monitoring."""
        # Create features with quality issues
        problematic_features = pd.DataFrame({
            'good_feature': np.random.uniform(0, 1, 1000),
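            # Convert to a list first: list + ndarray broadcasts element-wise instead of concatenating
            'high_null_feature': [np.nan] * 500 + np.random.uniform(0, 1, 500).tolist(),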
            'constant_feature': [1.0] * 1000,
            'high_cardinality': [f'value_{i}' for i in range(1000)]
        })
        
        quality_report = feature_monitor.monitor_feature_quality(
            problematic_features
        )
        
        assert quality_report['features_with_issues'] > 0
        assert quality_report['quality_score'] < 100
    
    def test_online_feature_serving(self):
        """Test online feature serving API."""
        from serving.feature_server import FeatureStoreService
        
        mock_store = Mock()
        mock_store.get_online_features.return_value = Mock(
            to_dict=lambda: {
                'results': [
                    {'customer_id': 'customer_1', 'lifetime_value': 5000.0}
                ]
            }
        )
        
        service = FeatureStoreService()
        service.store = mock_store
        
        result = service.get_online_features(
            entity_rows=[{'customer_id': 'customer_1'}],
            feature_refs=['customer_features:lifetime_value']
        )
        
        assert 'metadata' in result
        assert 'results' in result
    
    def test_feature_cache(self):
        """Test feature caching in Redis."""
        import json
        from unittest.mock import Mock
        
        mock_redis = Mock()
        mock_redis.get.return_value = None
        mock_redis.setex.return_value = True
        
        cache_key = "features:test_hash"
        cache_ttl = 300
        
        # Simulate cache miss
        cached = mock_redis.get(cache_key)
        assert cached is None
        
        # Simulate cache set
        feature_data = {'test': 'data'}
        mock_redis.setex(cache_key, cache_ttl, json.dumps(feature_data))
        
        mock_redis.setex.assert_called_once_with(
            cache_key, cache_ttl, json.dumps(feature_data)
        )
    
    def test_historical_features(self):
        """Test historical feature retrieval for training."""
        from datetime import datetime
        
        # Create entity dataframe
        entity_df = pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(100)],
            'event_date': [datetime(2024, 1, 15)] * 100
        })
        
        # Simulate historical features
        historical_features = pd.DataFrame({
            'customer_id': [f'customer_{i}' for i in range(100)],
            'event_date': [datetime(2024, 1, 15)] * 100,
            'lifetime_value': np.random.uniform(0, 10000, 100),
            'total_orders': np.random.randint(1, 50, 100)
        })
        
        # Verify schema
        assert 'customer_id' in historical_features.columns
        assert 'lifetime_value' in historical_features.columns
        assert len(historical_features) == len(entity_df)
    
    def test_feature_pipeline(self):
        """Test end-to-end feature pipeline."""
        from pyspark.sql import SparkSession
        
        spark = SparkSession.builder \
            .appName("TestFeaturePipeline") \
            .master("local[*]") \
            .getOrCreate()
        
        try:
            # Create sample transaction data
            transactions_data = [
                ('customer_1', 'order_1', 100.0, '2024-01-15'),
                ('customer_1', 'order_2', 150.0, '2024-01-16'),
                ('customer_2', 'order_3', 200.0, '2024-01-15')
            ]

            transactions_df = spark.createDataFrame(
                transactions_data,
                ['customer_id', 'order_id', 'amount', 'order_date']
            )

            # Constructing the engineer is a smoke test; its methods are not exercised here
            from features.spark_engineering import SparkFeatureEngineer

            config = {'db_url': 'jdbc:postgresql://localhost/test'}
            engineer = SparkFeatureEngineer(spark, config)

            # Simple aggregation test
            customer_features = transactions_df \
                .groupBy('customer_id') \
                .agg(
                    {'amount': 'sum', 'order_id': 'count'}
                )

            assert 'customer_id' in customer_features.columns
            assert 'sum(amount)' in customer_features.columns
        finally:
            # Stop the local session even when an assertion fails
            spark.stop()
    
    def test_feature_validation(self):
        """Test feature validation rules."""
        validation_rules = [
            {
                'feature': 'lifetime_value',
                'type': 'range',
                'min': 0,
                'max': 100000
            },
            {
                'feature': 'customer_segment',
                'type': 'enum',
                'values': ['VIP', 'Premium', 'Regular', 'New']
            },
            {
                'feature': 'total_orders',
                'type': 'non_negative'
            }
        ]
        
        # Test data
        test_data = {
            'lifetime_value': 5000,
            'customer_segment': 'VIP',
            'total_orders': 10
        }
        
        # Validate
        for rule in validation_rules:
            feature = rule['feature']
            value = test_data[feature]
            
            if rule['type'] == 'range':
                assert rule['min'] <= value <= rule['max']
            elif rule['type'] == 'enum':
                assert value in rule['values']
            elif rule['type'] == 'non_negative':
                assert value >= 0
    
    def test_feature_monitoring_report(self, feature_monitor, sample_features):
        """Test comprehensive monitoring report generation."""
        with patch.object(feature_monitor.feature_store, 'list_feature_views') as mock_views:
            mock_fv = Mock()
            mock_fv.name = 'customer_features'
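            # Mock(name=...) only sets the mock's repr, so assign .name explicitly
            lifetime_value, total_orders = Mock(), Mock()
            lifetime_value.name = 'lifetime_value'
            total_orders.name = 'total_orders'
            mock_fv.features = [lifetime_value, total_orders]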
            mock_views.return_value = [mock_fv]
            
            with patch.object(feature_monitor.feature_store, 'get_online_features') as mock_features:
                mock_features.return_value = {
                    'results': sample_features.to_dict('records')
                }
                
                report = feature_monitor.generate_monitoring_report()
                
                assert 'feature_views' in report
                assert 'customer_features' in report['feature_views']

Cost Analysis

Monthly Cost Breakdown (Production)

| Component | Specification | Monthly Cost |
|---|---|---|
| Redis Cluster | 6x cache.r6g.xlarge | $2,400 |
| EMR Cluster | 3x r5.2xlarge (intermittent) | $800 |
| ECS (Feature Server) | 3x Fargate tasks | $300 |
| S3 Storage | 2TB feature data | $46 |
| Snowflake | Feature store queries | $200 |
| Data Transfer | Cross-AZ | $100 |
| Total | | $3,846 |
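
The total can be checked, and each component's share of the bill computed, with a few lines of arithmetic. This sketch only reuses the figures from the table; the helper name cost_shares is illustrative.

```python
# Monthly cost figures copied from the table above (USD).
monthly_costs = {
    "Redis Cluster": 2400,
    "EMR Cluster": 800,
    "ECS (Feature Server)": 300,
    "S3 Storage": 46,
    "Snowflake": 200,
    "Data Transfer": 100,
}


def cost_shares(costs: dict) -> dict:
    """Return each component's share of the total, as a percentage."""
    total = sum(costs.values())
    return {name: round(100 * amount / total, 1) for name, amount in costs.items()}


total = sum(monthly_costs.values())
shares = cost_shares(monthly_costs)
print(f"Total: ${total:,}")
for name, share in sorted(shares.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {share}%")
```

Redis accounts for well over half of the total, so it is the first place to look when trimming costs.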

Cost Optimization Strategies

💡

Tip: Optimize feature store costs with these strategies:

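  1. Feature Caching: A high cache hit rate (for example, 90%) means only about one request in ten reaches the backing store
  2. Feature Pruning: Remove low-importance features so they are no longer computed or stored
  3. Batch vs Real-time: Compute features that tolerate staleness in batch and reserve streaming for latency-sensitive ones
  4. Spot Instances: Run EMR feature engineering jobs on Spot Instances, since batch jobs can be retried
  5. Auto-scaling: Scale Redis capacity with traffic instead of provisioning for peak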
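
To make the caching point concrete, the load that still reaches the backing store is the request volume multiplied by the miss rate. A minimal sketch follows; the function name and the request volume are illustrative assumptions, not measurements from this project.

```python
def backend_lookups(requests: int, cache_hit_rate: float) -> int:
    """Requests that miss the cache and must be served by the backing store."""
    if not 0.0 <= cache_hit_rate <= 1.0:
        raise ValueError("cache_hit_rate must be between 0 and 1")
    return round(requests * (1.0 - cache_hit_rate))


# Hypothetical volume: one million feature requests at a 90% hit rate.
print(backend_lookups(1_000_000, 0.9))
```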

Performance Metrics

| Metric | Before Feature Store | After Feature Store | Improvement |
|---|---|---|---|
| Feature Development | 2 weeks | 2 days | 7x faster |
| Model Deployment | 1 week | 1 hour | 168x faster |
| Feature Serving Latency | 50ms | 2ms | 25x faster |
| Training-Serving Skew | 15% | < 1% | 15x better |

Interview Talking Points

Architecture Decisions

ℹ️

Best Practice: Focus on these feature store concepts in interviews:

  1. Why Feast over custom feature store?

    • Open-source and cloud-agnostic
    • Native integration with Spark, Flink
    • Online/offline store separation
    • Feature versioning and lineage
  2. Why Redis for online store?

    • Sub-millisecond latency
    • Native data structure support
    • Cluster mode for scaling
    • Pub/Sub for real-time updates
  3. How to prevent training-serving skew?

    • Same feature definitions for both
    • Point-in-time correctness
    • Feature versioning and snapshots
    • Automated validation tests
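
Point-in-time correctness is the least intuitive of these points, so a small illustration helps in an interview. The sketch below shows the idea in plain Python: for each training event, use only the latest feature value recorded at or before the event time. The function name and the sample history are hypothetical; in practice a feature store performs this join when building training data.

```python
from bisect import bisect_right
from datetime import datetime


def point_in_time_lookup(feature_history, event_time):
    """Return the latest feature value observed at or before event_time.

    feature_history: list of (timestamp, value) tuples sorted by timestamp.
    Returns None when no value existed yet, which prevents leaking
    future data into a training row.
    """
    timestamps = [ts for ts, _ in feature_history]
    idx = bisect_right(timestamps, event_time)
    if idx == 0:
        return None
    return feature_history[idx - 1][1]


# Hypothetical history of lifetime_value for one customer.
history = [
    (datetime(2024, 1, 10), 100.0),
    (datetime(2024, 1, 15), 250.0),
    (datetime(2024, 1, 20), 400.0),
]

print(point_in_time_lookup(history, datetime(2024, 1, 16)))  # 250.0
```

An event on January 16 sees the January 15 value, never the January 20 one, which is exactly what the model would have seen at serving time.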

Common Interview Questions

Q: "How do you handle feature freshness for real-time models?"

freshness_strategies = {
    "Streaming Features": "Flink/Spark Streaming for real-time computation",
    "Cache Invalidation": "TTL-based expiration with event-driven updates",
    "Hybrid Approach": "Batch for historical, streaming for recent",
    "Feature Monitoring": "Alert on staleness and freshness violations"
}
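
The cache invalidation strategy can be sketched with an in-memory stand-in for Redis. The class below is hypothetical and only illustrates the combination of TTL-based expiration and event-driven invalidation; a real deployment would rely on Redis key expiry instead.

```python
import time


class TTLFeatureCache:
    """In-memory stand-in for a TTL cache with event-driven invalidation."""

    def __init__(self, ttl_seconds: float, clock=time.monotonic):
        self.ttl_seconds = ttl_seconds
        self._clock = clock
        self._entries = {}  # key -> (value, stored_at)

    def set(self, key, value):
        self._entries[key] = (value, self._clock())

    def get(self, key):
        entry = self._entries.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if self._clock() - stored_at > self.ttl_seconds:
            # Expired: drop it so the caller recomputes or refetches.
            del self._entries[key]
            return None
        return value

    def invalidate(self, key):
        """Called when an upstream event changes the feature."""
        self._entries.pop(key, None)
```

Injecting the clock keeps the class testable without sleeping: a test can advance a fake clock past the TTL and check that the entry is gone.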

Q: "How do you ensure feature quality at scale?"

quality_framework = {
    "Validation Rules": "Schema validation, range checks, null handling",
    "Statistical Tests": "Distribution monitoring, drift detection",
    "Automated Testing": "CI/CD pipeline for feature definitions",
    "Monitoring Dashboards": "Real-time feature health metrics"
}
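
For the statistical tests entry, one commonly used drift measure is the Population Stability Index (PSI). The sketch below is an illustration of that technique, not the drift check used elsewhere in this project; the equal-width binning and the 1e-6 floor for empty bins are assumptions made for the example.

```python
import math


def population_stability_index(expected, actual, bins=10):
    """PSI between two numeric samples, using equal-width bins over `expected`."""
    lo, hi = min(expected), max(expected)
    width = (hi - lo) / bins or 1.0  # fall back to 1.0 for a constant sample

    def bin_fractions(values):
        counts = [0] * bins
        for v in values:
            idx = int((v - lo) / width)
            idx = min(max(idx, 0), bins - 1)  # clamp out-of-range values
            counts[idx] += 1
        # Small floor avoids log(0) for empty bins.
        return [max(c / len(values), 1e-6) for c in counts]

    e = bin_fractions(expected)
    a = bin_fractions(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))


baseline = list(range(100))
shifted = [x + 50 for x in baseline]
print(population_stability_index(baseline, baseline))
print(population_stability_index(baseline, shifted))
```

A PSI near zero indicates matching distributions. A commonly cited rule of thumb treats values above roughly 0.2 as a shift worth investigating, though the right threshold depends on the feature.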

Q: "How do you version features for reproducibility?"

versioning_strategy = {
    "Feature Snapshots": "Point-in-time snapshots for training",
    "Git Integration": "Feature definitions in version control",
    "Registry": "Central metadata store for feature versions",
    "Lineage Tracking": "Track feature dependencies and transformations"
}
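
One way to make feature versions reproducible is to derive the version id from the definition itself, so any change to the definition yields a new id. A minimal sketch, assuming definitions can be represented as JSON-serializable dicts; the function name and the dict fields are hypothetical.

```python
import hashlib
import json


def feature_definition_version(definition: dict) -> str:
    """Derive a short, deterministic version id from a feature definition."""
    # sort_keys makes the hash independent of dict key order.
    canonical = json.dumps(definition, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]


definition = {"name": "lifetime_value", "dtype": "Float32", "ttl_days": 1}
print(feature_definition_version(definition))
```

Storing this id alongside each training run makes it possible to tell later which definition produced the training data.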

Deployment Checklist

  • Set up Redis cluster for online store
  • Configure Snowflake for offline store
  • Deploy Feast feature store
  • Set up Spark feature engineering pipeline
  • Deploy feature serving API
  • Configure feature monitoring
  • Set up CI/CD for feature definitions
  • Test online/offline consistency
  • Document feature catalog
  • Train ML team on feature store usage
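
The online/offline consistency item in the checklist can be tested by fetching the same entity from both stores and comparing values. The sketch below shows only the comparison step, with a tolerance for floating-point features; the function name and the sample rows are hypothetical.

```python
import math


def find_inconsistencies(offline_row: dict, online_row: dict, rel_tol: float = 1e-6) -> list:
    """Return feature names whose online and offline values disagree."""
    mismatched = []
    for name, offline_value in offline_row.items():
        online_value = online_row.get(name)
        both_numeric = all(
            isinstance(v, (int, float)) and not isinstance(v, bool)
            for v in (offline_value, online_value)
        )
        if both_numeric:
            # Tolerate tiny float differences from serialization.
            equal = math.isclose(offline_value, online_value, rel_tol=rel_tol)
        else:
            equal = offline_value == online_value
        if not equal:
            mismatched.append(name)
    return mismatched


offline = {"lifetime_value": 5000.0, "total_orders": 10, "customer_segment": "VIP"}
online = {"lifetime_value": 5000.0000001, "total_orders": 11, "customer_segment": "VIP"}
print(find_inconsistencies(offline, online))  # ['total_orders']
```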

⚠️

Warning: Always validate feature freshness and quality before serving to production models. Stale or low-quality features can significantly degrade model performance.


This project demonstrates ML infrastructure skills and is highly relevant for ML engineering and data science platform interviews.
