🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services ℹ️ About ✉️ Contact View Pricing Plans from $10

Model Serving: Batch, Real-Time, Edge, A/B Testing, Shadow

MLOps / Model Serving ⭐ Premium

Advertisement

Interview Question (Hard) — Asked at: Uber, Netflix, Spotify, Apple, Tesla

"Design a model serving architecture that supports batch, real-time, and edge inference. How do you implement A/B testing and shadow mode deployments while maintaining low latency?"

Model Serving Architecture Overview

Model serving is the process of deploying trained ML models to production for inference. The serving pattern depends on latency requirements, throughput, cost constraints, and deployment environment.

Serving Pattern Decision Matrix

| Pattern | Latency | Throughput | Cost | Use Case |
|---|---|---|---|---|
| Batch Inference | Hours | Very High | Low | Report generation, ETL |
| Real-Time REST | 10-100ms | Medium | High | Web APIs, mobile |
| Real-Time gRPC | 1-10ms | High | High | High-frequency trading |
| Edge Inference | 1-5ms | Low | Low | IoT, mobile offline |
| Streaming | 100ms-1s | High | Medium | Event-driven systems |

Batch Inference Patterns

PySpark Batch Inference

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, DoubleType
import mlflow
import mlflow.spark

class BatchInferencePipeline:
    """Score Parquet datasets offline with an MLflow-logged Spark ML model."""

    def __init__(self, spark: "SparkSession", model_uri: str):
        """Load the Spark model stored at ``model_uri``.

        Args:
            spark: Spark session used for every read and write.
            model_uri: MLflow model URI, e.g. ``models:/fraud_detection/Production``.
        """
        self.spark = spark
        self.model = mlflow.spark.load_model(model_uri)
        # run_batch_inference stamps this onto every output row, so it must be set here.
        self.model_version = self._model_version_from_uri(model_uri)

    @staticmethod
    def _model_version_from_uri(model_uri: str) -> str:
        """Return the last path segment of ``model_uri``.

        For ``models:/<name>/<version-or-stage>`` URIs this is the version
        number or stage name.
        """
        return model_uri.rstrip("/").split("/")[-1]

    def run_batch_inference(self, input_path: str,
                           output_path: str,
                           batch_size: int = 1000000):
        """Run batch inference on a large Parquet dataset.

        Args:
            input_path: Parquet input location.
            output_path: Parquet output location; overwritten, partitioned by ``batch_id``.
            batch_size: target rows per Spark partition (must be positive).

        Returns:
            Dict with ``total_rows``, ``total_batches``, ``output_path`` and ``timestamp``.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        from datetime import datetime  # not imported by this snippet's header

        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        input_df = self.spark.read.parquet(input_path)

        # One identifier for the whole run, used as the output partition key.
        input_df = input_df.withColumn(
            "batch_id",
            F.date_format(F.current_timestamp(), "yyyyMMdd_HHmmss")
        )

        # Count once: the value sizes the partitions and is reported in the stats.
        total_rows = input_df.count()
        num_partitions = max(1, total_rows // batch_size)
        input_df = input_df.repartition(num_partitions)

        predictions = self.model.transform(input_df)

        predictions = predictions \
            .withColumn("inference_timestamp", F.current_timestamp()) \
            .withColumn("model_version", F.lit(self.model_version))

        predictions.write \
            .mode("overwrite") \
            .partitionBy("batch_id") \
            .parquet(output_path)

        return {
            "total_rows": total_rows,
            "total_batches": num_partitions,
            "output_path": output_path,
            "timestamp": datetime.now().isoformat()
        }

    def run_incremental_inference(self, input_path: str,
                                   output_path: str,
                                   watermark_col: str,
                                   interval_minutes: int = 60):
        """Score only rows newer than the last stored watermark and append them.

        Args:
            input_path: Parquet input location.
            output_path: Parquet output location (appended to); the watermark is
                kept under ``<output_path>/_metadata/watermark``.
            watermark_col: monotonically increasing column used to find new rows.
            interval_minutes: accepted for API compatibility; currently unused.
        """
        last_watermark = self._get_last_watermark(output_path)

        input_df = self.spark.read.parquet(input_path) \
            .filter(F.col(watermark_col) > last_watermark)

        # Fix the upper bound *before* scoring. Computing it after the write
        # could include rows that landed in between and were never scored,
        # which would then be skipped permanently. None means no new rows.
        new_watermark = input_df.agg(F.max(watermark_col)).collect()[0][0]
        if new_watermark is None:
            print("No new data to process")
            return

        input_df = input_df.filter(F.col(watermark_col) <= new_watermark)

        predictions = self.model.transform(input_df)

        predictions.write \
            .mode("append") \
            .parquet(output_path)

        self._update_watermark(output_path, new_watermark)

    def _get_last_watermark(self, output_path: str):
        """Return the stored watermark, or ``"1970-01-01"`` if none can be read."""
        try:
            metadata = self.spark.read.parquet(
                f"{output_path}/_metadata/watermark"
            ).collect()[0]
            return metadata["watermark"]
        except Exception:
            # First run (path missing) or unreadable metadata: start from the epoch.
            return "1970-01-01"

    def _update_watermark(self, output_path: str, watermark):
        """Overwrite the stored watermark together with the time it was updated."""
        from datetime import datetime  # not imported by this snippet's header

        watermark_df = self.spark.createDataFrame(
            [(watermark, datetime.now().isoformat())],
            ["watermark", "updated_at"]
        )
        watermark_df.write \
            .mode("overwrite") \
            .parquet(f"{output_path}/_metadata/watermark")

Real-Time Serving with FastAPI

REST API Model Server

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import uvicorn
import mlflow
import numpy as np
import pandas as pd
from datetime import datetime
import asyncio
import logging
from contextlib import asynccontextmanager
import redis.asyncio as redis
import json

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelServer:
    """Hold the loaded MLflow model and an optional Redis feature cache."""

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        """Create an empty server; call ``load_model`` before serving.

        Args:
            redis_url: default Redis URL used by ``load_model`` for the feature cache.
        """
        self.model = None
        self.preprocessor = None  # reserved; load_model does not populate it
        self.model_version = None
        self.redis = None
        self.redis_url = redis_url
        self.feature_cache_ttl = 300  # seconds (5 minutes)

    async def load_model(self, model_uri: str, redis_url: Optional[str] = None):
        """Load the pyfunc model and connect the Redis feature cache.

        Args:
            model_uri: MLflow model URI, e.g. ``models:/fraud_detection/Production``.
            redis_url: overrides the URL given to the constructor.
        """
        self.model = mlflow.pyfunc.load_model(model_uri)
        # Last URI segment: a version number or stage name for models:/ URIs.
        self.model_version = model_uri.rstrip("/").split("/")[-1]

        # On reload, release the previous connection instead of leaking it.
        if self.redis is not None:
            await self.redis.close()

        self.redis = await redis.from_url(
            redis_url or self.redis_url,
            encoding="utf-8",
            decode_responses=True
        )

        logger.info("Model loaded: %s", self.model_version)

    async def get_cached_features(self, entity_id: str) -> Optional[Dict]:
        """Return cached features for ``entity_id``, or None if absent or no cache."""
        if self.redis:
            cached = await self.redis.get(f"features:{entity_id}")
            if cached:
                return json.loads(cached)
        return None

    async def set_cached_features(self, entity_id: str,
                                   features: Dict, ttl: Optional[int] = None):
        """Cache ``features`` as JSON for ``ttl`` seconds (default: feature_cache_ttl)."""
        if self.redis:
            await self.redis.setex(
                f"features:{entity_id}",
                ttl or self.feature_cache_ttl,
                json.dumps(features)
            )

# Global model server shared by all request handlers.
model_server = ModelServer()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and close the Redis connection on shutdown.

    The model URI comes from the ``MODEL_URI`` environment variable and
    defaults to ``models:/fraud_detection/Production``.
    """
    import os  # not imported by this snippet's header

    model_uri = os.environ.get("MODEL_URI", "models:/fraud_detection/Production")
    await model_server.load_model(model_uri)
    try:
        yield
    finally:
        # Runs even when shutdown is triggered by an exception, so the
        # connection is never leaked.
        client = model_server.redis
        if client is not None:
            # redis-py >= 5 provides aclose(); older asyncio clients only close().
            close = getattr(client, "aclose", None) or client.close
            await close()

app = FastAPI(
    title="ML Model Server",
    version="1.0.0",
    lifespan=lifespan
)

# Any origin may call the API, but without browser credentials. With
# allow_origins=["*"], setting allow_credentials=True makes Starlette echo the
# caller's Origin back and allow credentials, so any website could send
# cookie-authenticated requests. If credentials are ever needed, list explicit
# origins first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

class PredictionRequest(BaseModel):
    # Request body for POST /predict and the element type of batch requests.
    entity_id: str = Field(..., description="Unique entity identifier")
    features: Dict[str, float] = Field(
        ...,
        description="Feature values for prediction"
    )
    request_id: Optional[str] = Field(
        None,
        description="Optional request ID for tracking"
    )

    # Pydantic v2 configuration. The nested ``class Config`` form is deprecated
    # in v2, and ``json_schema_extra`` is a v2 key. The example is shown in the
    # generated OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "entity_id": "user_12345",
                "features": {
                    "transaction_amount": 150.00,
                    "time_since_last_transaction": 3600,
                    "merchant_category": 1,
                    "user_account_age_days": 365
                },
                "request_id": "req_abc123"
            }
        }
    }

# Response body for POST /predict; also the per-item element of a batch response.
class PredictionResponse(BaseModel):
    prediction: float  # model output for this request
    probability: float  # score from which `confidence` is derived
    confidence: float  # max(probability, 1 - probability) as computed by the endpoints
    model_version: str  # ModelServer.model_version at serving time
    latency_ms: float  # server-side latency; batch items get an even share of the batch total
    request_id: Optional[str] = None  # echoed back from the request when provided
    timestamp: datetime  # when the response object was built

# Request body for POST /predict/batch: the instances are scored together in
# a single model.predict call.
class BatchPredictionRequest(BaseModel):
    instances: List[PredictionRequest]

# Response body for POST /predict/batch.
class BatchPredictionResponse(BaseModel):
    predictions: List[PredictionResponse]  # one entry per request instance, in request order
    total_latency_ms: float  # server-side time for the whole batch, in milliseconds
    batch_size: int  # number of predictions returned

def _split_prediction(raw):
    """Turn one row of model output into ``(prediction, probability)``.

    A scalar or single-element row is used as both values (the model is
    assumed to emit a positive-class score). A multi-element row is assumed to
    hold per-class probabilities of a binary classifier: the probability is
    element 1 and the prediction is the index of the largest element.
    """
    values = np.asarray(raw, dtype=float).ravel()
    if values.size == 1:
        score = float(values[0])
        return score, score
    return float(np.argmax(values)), float(values[1])


async def _run_model(input_df):
    """Call the synchronous ``model.predict`` in the default executor.

    Running it inline would block the event loop and stall every other
    in-flight request until inference finished.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, model_server.model.predict, input_df)


def _require_model():
    """Raise HTTP 503 while no model is loaded instead of failing with a 500."""
    if model_server.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Score a single entity.

    Cached features for ``request.entity_id`` only fill in values the caller
    did not send; features in the request always take precedence.
    """
    import time  # monotonic clock for latency; not imported by this snippet's header

    start = time.perf_counter()

    try:
        _require_model()

        cached_features = await model_server.get_cached_features(
            request.entity_id
        )

        # Request values win over (up to TTL-old) cached ones. Build a new
        # dict so the request object is not mutated.
        features = {**(cached_features or {}), **request.features}

        input_df = pd.DataFrame([features])
        raw = await _run_model(input_df)

        prediction, probability = _split_prediction(np.asarray(raw)[0])
        confidence = max(probability, 1 - probability)

        latency_ms = (time.perf_counter() - start) * 1000

        logger.info("Prediction: %s, latency: %.2fms", request.entity_id, latency_ms)

        return PredictionResponse(
            prediction=prediction,
            probability=probability,
            confidence=confidence,
            model_version=model_server.model_version,
            latency_ms=latency_ms,
            request_id=request.request_id,
            timestamp=datetime.now()
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    """Score several entities with one model call.

    Cached features are not consulted here. Each item's ``latency_ms`` is the
    batch total divided evenly across items.
    """
    import time  # monotonic clock for latency; not imported by this snippet's header

    start = time.perf_counter()

    try:
        _require_model()

        instances = request.instances
        if not instances:
            # Nothing to score; avoids dividing by zero below.
            return BatchPredictionResponse(
                predictions=[], total_latency_ms=0.0, batch_size=0
            )

        input_df = pd.DataFrame([inst.features for inst in instances])
        raw = np.asarray(await _run_model(input_df))
        scored = [_split_prediction(row) for row in raw]

        total_latency_ms = (time.perf_counter() - start) * 1000
        per_item_latency = total_latency_ms / len(scored)

        responses = []
        for inst, (prediction, probability) in zip(instances, scored):
            responses.append(PredictionResponse(
                prediction=prediction,
                probability=probability,
                confidence=max(probability, 1 - probability),
                model_version=model_server.model_version,
                latency_ms=per_item_latency,
                request_id=inst.request_id,
                timestamp=datetime.now()
            ))

        return BatchPredictionResponse(
            predictions=responses,
            total_latency_ms=total_latency_ms,
            batch_size=len(responses)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    """Report liveness together with whether a model is currently loaded."""
    loaded = model_server.model is not None
    payload = {
        "status": "healthy",
        "model_loaded": loaded,
        "model_version": model_server.model_version,
        "timestamp": datetime.now(),
    }
    return payload

@app.get("/metrics")
async def metrics():
    """Return placeholder serving metrics.

    Request and latency tracking are not implemented, so every counter is a
    hard-coded 0. The payload is plain JSON, not Prometheus text format.
    """
    not_tracked = 0
    return dict(
        model_version=model_server.model_version,
        requests_total=not_tracked,
        latency_p50=not_tracked,
        latency_p99=not_tracked,
    )

if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")

ℹ️

For sub-10ms latency, use gRPC instead of REST. NVIDIA Triton Inference Server provides optimized serving with dynamic batching, model ensembles, and multi-GPU support.

A/B Testing Framework

Statistical A/B Testing Implementation

import numpy as np
from scipy import stats
from typing import Dict, List, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta
import hashlib
import json

@dataclass
class ABTestConfig:
    """Settings for a two-variant (control A vs. treatment B) model A/B test.

    ``traffic_split`` is the *fraction* (0.0-1.0) of entities routed to the
    treatment model B. Values are validated on construction.
    """
    name: str
    model_a: str  # Control
    model_b: str  # Treatment
    traffic_split: float  # Fraction of entities sent to treatment (0.0-1.0)
    min_samples: int = 1000
    significance_level: float = 0.05
    primary_metric: str = "conversion_rate"
    secondary_metrics: List[str] = None  # None means "no secondary metrics"
    duration_days: int = 7
    # Relative lift (0.05 = +5%) used when planning the required sample size.
    minimum_detectable_effect: float = 0.05

    def __post_init__(self):
        # A percentage such as 10 (meaning 10%) would otherwise silently send
        # every entity to B.
        if not 0.0 <= self.traffic_split <= 1.0:
            raise ValueError(
                f"traffic_split must be a fraction in [0, 1], got {self.traffic_split!r}"
            )
        if not 0.0 < self.significance_level < 1.0:
            raise ValueError(
                f"significance_level must be in (0, 1), got {self.significance_level!r}"
            )
        if self.min_samples < 1:
            raise ValueError(f"min_samples must be >= 1, got {self.min_samples!r}")


class ABTestManager:
    """Assign entities to variants, collect outcomes and analyse the test."""

    def __init__(self, config: ABTestConfig):
        self.config = config
        self.results_a = []
        self.results_b = []
        self.start_time = datetime.now()

    def assign_variant(self, entity_id: str) -> str:
        """Deterministically assign ``entity_id`` to "A" or "B".

        The MD5 hash of ``"<test name>:<entity_id>"`` is mapped to one of 100
        buckets, so splits have 1% granularity.
        """
        bucket = int(
            hashlib.md5(
                f"{self.config.name}:{entity_id}".encode()
            ).hexdigest(),
            16
        ) % 100
        return "B" if bucket < self.config.traffic_split * 100 else "A"

    def log_result(self, variant: str, entity_id: str,
                   prediction: float, actual: float = None):
        """Record a prediction (and, once known, its outcome) for a variant.

        Raises:
            ValueError: if ``variant`` is not "A" or "B".
        """
        if variant not in ("A", "B"):
            # Previously any other label was silently counted as B.
            raise ValueError(f"variant must be 'A' or 'B', got {variant!r}")

        result = {
            'entity_id': entity_id,
            'prediction': prediction,
            'actual': actual,
            'timestamp': datetime.now()
        }
        (self.results_a if variant == "A" else self.results_b).append(result)

    def analyze_results(self) -> Dict:
        """Run the significance test for the primary metric.

        Returns ``{'status': 'insufficient_data', ...}`` until both variants
        have ``min_samples`` results with a known ``actual``. After that it
        returns the test outcome, lift, confidence interval and sample-size
        plan; ``required_samples_per_variant`` is None when the proportion
        formula does not apply.
        """
        completed_a = [r for r in self.results_a if r['actual'] is not None]
        completed_b = [r for r in self.results_b if r['actual'] is not None]

        if len(completed_a) < self.config.min_samples or \
           len(completed_b) < self.config.min_samples:
            return {
                'status': 'insufficient_data',
                'samples_a': len(completed_a),
                'samples_b': len(completed_b),
                'required': self.config.min_samples
            }

        values_a = np.asarray([r['actual'] for r in completed_a], dtype=float)
        values_b = np.asarray([r['actual'] for r in completed_b], dtype=float)
        metric_a = np.mean(values_a)
        metric_b = np.mean(values_b)
        alpha = self.config.significance_level
        is_proportion = self.config.primary_metric == "conversion_rate"

        if is_proportion:
            # Chi-squared test on the 2x2 success/failure table (actuals assumed 0/1).
            successes_a, successes_b = values_a.sum(), values_b.sum()
            contingency = np.array([
                [successes_a, len(values_a) - successes_a],
                [successes_b, len(values_b) - successes_b],
            ])
            if (contingency.sum(axis=0) == 0).any():
                # Every outcome is identical across both arms: chi2_contingency
                # would raise on zero expected counts, and there is no
                # difference to detect.
                p_value = 1.0
            else:
                _, p_value, _, _ = stats.chi2_contingency(contingency)
        else:
            # Welch's t-test. It does not assume equal variances, which matters
            # when arm sizes differ (e.g. a 90/10 split).
            _, p_value = stats.ttest_ind(values_a, values_b, equal_var=False)
            if np.isnan(p_value):
                p_value = 1.0  # degenerate: both arms constant and equal

        # Normal-approximation CI for B - A at the configured confidence level.
        diff = metric_b - metric_a
        z = stats.norm.ppf(1 - alpha / 2)
        se_diff = np.sqrt(
            np.var(values_a, ddof=1) / len(values_a)
            + np.var(values_b, ddof=1) / len(values_b)
        )
        ci_lower = diff - z * se_diff
        ci_upper = diff + z * se_diff

        significant = p_value < alpha
        winner = None
        if significant:
            winner = "B" if metric_b > metric_a else "A"

        # Relative to |baseline| so negative baselines keep the right sign.
        lift = diff / abs(metric_a) * 100 if metric_a != 0 else 0

        required_n = None
        if is_proportion:
            try:
                required_n = self._calculate_required_sample_size(
                    metric_a, effect_size=self.config.minimum_detectable_effect
                )
            except ValueError:
                required_n = None  # baseline of 0 or too close to 1

        return {
            'status': 'completed' if significant else 'running',
            'significant': significant,
            'winner': winner,
            'p_value': p_value,
            'metric_a': metric_a,
            'metric_b': metric_b,
            'lift': lift,
            'confidence_interval': [ci_lower, ci_upper],
            'samples_a': len(completed_a),
            'samples_b': len(completed_b),
            'duration_days': (datetime.now() - self.start_time).days,
            'required_samples_per_variant': required_n,
        }

    def _calculate_required_sample_size(self, baseline_rate: float,
                                        effect_size: float,
                                        power: float = 0.8) -> int:
        """Return the per-variant sample size needed to detect a relative lift.

        Uses the two-sided two-proportion z-test formula at
        ``config.significance_level``.

        Raises:
            ValueError: if the baseline or target rate is not strictly inside
                (0, 1), or the effect is zero; the formula is undefined there.
        """
        alpha = self.config.significance_level

        p1 = baseline_rate
        p2 = baseline_rate * (1 + effect_size)
        if not (0 < p1 < 1 and 0 < p2 < 1) or p1 == p2:
            raise ValueError(
                f"sample size undefined for baseline={p1!r}, target={p2!r}"
            )

        pooled_p = (p1 + p2) / 2

        z_alpha = stats.norm.ppf(1 - alpha / 2)
        z_beta = stats.norm.ppf(power)

        numerator = (z_alpha * np.sqrt(2 * pooled_p * (1 - pooled_p)) +
                     z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) ** 2
        denominator = (p2 - p1) ** 2

        return int(np.ceil(numerator / denominator))

    def should_stop_early(self) -> Tuple[bool, str]:
        """Decide whether to stop the test early (for ethical/resource reasons)."""
        results = self.analyze_results()

        # Repeatedly checking p-values ("peeking") inflates false positives;
        # the strict 0.001 threshold is a conservative guard against that.
        if results.get('p_value', 1) < 0.001:
            return True, "Highly significant result detected"

        # Measured directly: analyze_results omits 'duration_days' while data
        # is insufficient, which used to keep starved tests running forever.
        if (datetime.now() - self.start_time).days >= self.config.duration_days:
            return True, "Maximum duration reached"

        return False, "Continue test"

A/B Testing Traffic Management

# kubernetes/ab-testing.yaml
# HTTP routes are evaluated top to bottom. The two header-matched routes pin a
# caller to one model: set x-model-variant upstream (for example from
# ABTestManager.assign_variant) so a user sees the same variant on every
# request. Requests without the header fall through to the weighted 90/10
# split, which Istio decides per request rather than per user.
apiVersion: networking.istio.io/v1beta1
kind: VirtualService
metadata:
  name: model-serving-vs
spec:
  hosts:
  - model-serving
  http:
  - name: sticky-model-b
    match:
    - headers:
        x-model-variant:
          exact: B
    route:
    - destination:
        host: model-serving
        subset: model-b
    retries:
      attempts: 3
      perTryTimeout: 2s
    timeout: 10s
  - name: sticky-model-a
    match:
    - headers:
        x-model-variant:
          exact: A
    route:
    - destination:
        host: model-serving
        subset: model-a
    retries:
      attempts: 3
      perTryTimeout: 2s
    timeout: 10s
  - name: weighted-split
    route:
    - destination:
        host: model-serving
        subset: model-a
      weight: 90
    - destination:
        host: model-serving
        subset: model-b
      weight: 10
    retries:
      attempts: 3
      perTryTimeout: 2s
    timeout: 10s
---
apiVersion: networking.istio.io/v1beta1
kind: DestinationRule
metadata:
  name: model-serving-dr
spec:
  host: model-serving
  subsets:
  - name: model-a
    labels:
      model-version: v1
    trafficPolicy:
      connectionPool:
        tcp:
          maxConnections: 100
        http:
          h2UpgradePolicy: DEFAULT
          http1MaxPendingRequests: 100
          http2MaxRequests: 1000
  - name: model-b
    labels:
      model-version: v2
    trafficPolicy:
      connectionPool:
        tcp:
          maxConnections: 50
        http:
          h2UpgradePolicy: DEFAULT
          http1MaxPendingRequests: 50
          http2MaxRequests: 500

Shadow Mode Deployment

Shadow Mode Implementation

import asyncio
from typing import Optional
import time

class ShadowModeServer:
    """Serve from a primary model while mirroring each request to a shadow model.

    Only the primary prediction is returned to callers. The shadow result is
    recorded in ``comparison_results`` for offline comparison.
    """

    def __init__(self, primary_model, shadow_model,
                 max_comparisons: Optional[int] = None):
        """
        Args:
            primary_model: object with ``predict(DataFrame)``; its optional
                ``version`` attribute is reported in responses.
            shadow_model: candidate model with the same ``predict`` interface.
            max_comparisons: keep only the newest N comparisons (None = unbounded).
        """
        from collections import deque

        self.primary_model = primary_model
        self.shadow_model = shadow_model
        self.shadow_predictions = []  # kept for compatibility; not populated
        self.primary_predictions = []  # kept for compatibility; not populated
        # A bounded deque stops long-running servers from growing without limit.
        self.comparison_results = deque(maxlen=max_comparisons)
        self.shadow_errors = 0
        # asyncio keeps only weak references to tasks; hold strong ones so a
        # fire-and-forget shadow task cannot be garbage-collected mid-flight.
        self._shadow_tasks = set()

    async def predict_with_shadow(self, features: dict) -> dict:
        """Return the primary prediction and schedule the shadow prediction."""
        start_primary = time.perf_counter()
        primary_prediction = await self._run_prediction(
            self.primary_model, features
        )
        primary_latency = (time.perf_counter() - start_primary) * 1000

        task = asyncio.create_task(
            self._run_shadow(features, primary_prediction, primary_latency)
        )
        self._shadow_tasks.add(task)
        task.add_done_callback(self._shadow_tasks.discard)

        return {
            'prediction': primary_prediction,
            'latency_ms': primary_latency,
            'model_version': getattr(self.primary_model, 'version', None)
        }

    async def _run_shadow(self, features: dict, primary_prediction: float,
                          primary_latency: float):
        """Score with the shadow model and record the comparison; never raises.

        A failing shadow must not surface as an unhandled task exception, so
        errors are counted and logged instead.
        """
        import logging

        start_shadow = time.perf_counter()
        try:
            shadow_prediction = await self._run_prediction(
                self.shadow_model, features
            )
        except Exception:
            self.shadow_errors += 1
            logging.getLogger(__name__).exception("Shadow prediction failed")
            return None
        shadow_latency = (time.perf_counter() - start_shadow) * 1000

        self.comparison_results.append({
            'primary_prediction': primary_prediction,
            'shadow_prediction': shadow_prediction,
            'primary_latency': primary_latency,
            'shadow_latency': shadow_latency,
            'timestamp': time.time()
        })
        return shadow_prediction

    async def _run_prediction(self, model, features: dict) -> float:
        """Run ``model.predict`` on one row without blocking the event loop."""
        import pandas as pd

        input_df = pd.DataFrame([features])
        loop = asyncio.get_running_loop()
        prediction = await loop.run_in_executor(None, model.predict, input_df)
        return float(prediction[0])

    def get_comparison_metrics(self, tolerance: float = 0.0) -> dict:
        """Summarise primary-vs-shadow agreement and latency.

        Args:
            tolerance: two predictions "agree" when their absolute difference
                is <= tolerance (0.0 = exact equality).
        """
        import numpy as np

        if not self.comparison_results:
            return {'status': 'no_data'}

        rows = list(self.comparison_results)
        primary_preds = np.array([r['primary_prediction'] for r in rows], dtype=float)
        shadow_preds = np.array([r['shadow_prediction'] for r in rows], dtype=float)

        agreement = np.mean(
            np.isclose(primary_preds, shadow_preds, rtol=0.0, atol=tolerance)
        )

        # Pearson correlation is undefined for < 2 points or constant inputs.
        if len(rows) > 1 and primary_preds.std() > 0 and shadow_preds.std() > 0:
            correlation = np.corrcoef(primary_preds, shadow_preds)[0, 1]
        else:
            correlation = float('nan')

        return {
            'agreement_rate': agreement,
            'correlation': correlation,
            'primary_avg_latency': np.mean([r['primary_latency'] for r in rows]),
            'shadow_avg_latency': np.mean([r['shadow_latency'] for r in rows]),
            'total_comparisons': len(rows),
            'mean_absolute_difference': np.mean(np.abs(primary_preds - shadow_preds))
        }

# Shadow mode Kubernetes manifests (a Deployment and a Service in one string).
# The Deployment runs 2 replicas of registry.example.com/ml-model:v2 labelled
# app=ml-model / mode=shadow, with MODEL_MODE=shadow and PRIMARY_MODEL_URL
# pointing at http://model-primary:8080. The Service "model-shadow" selects
# those pods on port 8080.
# NOTE(review): nothing in these manifests sends traffic to the shadow pods;
# presumably an Istio mirror rule or an application-level fan-out (like
# ShadowModeServer) does. Confirm before deploying.
SHADOW_DEPLOYMENT = """
apiVersion: apps/v1
kind: Deployment
metadata:
  name: model-shadow
spec:
  replicas: 2
  selector:
    matchLabels:
      app: ml-model
      mode: shadow
  template:
    metadata:
      labels:
        app: ml-model
        mode: shadow
    spec:
      containers:
      - name: model-server
        image: registry.example.com/ml-model:v2
        ports:
        - containerPort: 8080
        env:
        - name: MODEL_MODE
          value: "shadow"
        - name: PRIMARY_MODEL_URL
          value: "http://model-primary:8080"
        resources:
          requests:
            memory: "2Gi"
            cpu: "1000m"
          limits:
            memory: "4Gi"
            cpu: "2000m"
---
apiVersion: v1
kind: Service
metadata:
  name: model-shadow
spec:
  selector:
    app: ml-model
    mode: shadow
  ports:
  - port: 8080
    targetPort: 8080
"""

⚠️

Shadow mode deployments double your inference load. Monitor resource usage carefully and set appropriate resource limits. Consider sampling (10-20% of traffic) for cost optimization.

Edge Deployment Patterns

TensorFlow Lite Edge Deployment

import tensorflow as tf
import numpy as np
from typing import Dict

class EdgeModelConverter:
    """Convert a saved Keras model into edge-friendly formats (TFLite, ONNX, C++ stub)."""

    def __init__(self, model_path: str):
        self.model = tf.keras.models.load_model(model_path)

    def convert_to_tflite(self, quantize: bool = True,
                          optimize: bool = True,
                          output_path: str = "model_edge.tflite",
                          representative_dataset=None) -> bytes:
        """Convert the model to TensorFlow Lite, write it to disk and return the bytes.

        Args:
            quantize: apply post-training quantization. Without
                ``representative_dataset`` this is float16 quantization; with
                one, integer quantization calibrated on that data.
            optimize: enable ``tf.lite.Optimize.DEFAULT``.
            output_path: where the ``.tflite`` file is written.
            representative_dataset: optional generator function yielding lists
                of *real* input samples for calibration. Random data would give
                meaningless quantization ranges.
        """
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)

        # Post-training quantization only takes effect with Optimize.DEFAULT,
        # so quantize=True implies it.
        if optimize or quantize:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]

        if quantize:
            if representative_dataset is not None:
                converter.representative_dataset = representative_dataset
            else:
                converter.target_spec.supported_types = [tf.float16]

        tflite_model = converter.convert()

        with open(output_path, 'wb') as f:
            f.write(tflite_model)

        print(f"Model converted: {len(tflite_model) / 1024:.1f} KB")

        return tflite_model

    def convert_to_onnx(self, output_path: str = "model.onnx"):
        """Convert to ONNX for cross-platform deployment; return ``output_path``."""
        import tf2onnx

        spec = (tf.TensorSpec(self.model.input_shape, tf.float32, name="input"),)

        tf2onnx.convert.from_keras(
            self.model, input_signature=spec, output_path=output_path
        )

        print(f"ONNX model saved: {output_path}")

        return output_path

    def create_edge_inference_code(self, tflite_path: str,
                                   output_path: str = "edge_inference.cpp") -> str:
        """Write a minimal C++ TFLite inference wrapper and return its source.

        Args:
            tflite_path: model file baked into the generated code as the
                default constructor argument.
            output_path: where the generated C++ file is written.
        """
        import json

        # json.dumps gives a double-quoted, backslash-escaped literal that is
        # also a valid C++ string literal for ordinary paths.
        model_path_literal = json.dumps(tflite_path)

        cpp_code = f"""
// edge_inference.cpp
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <vector>

// Default model file.
static const char* const kDefaultModelPath = {model_path_literal};

class EdgeModel {{
public:
    explicit EdgeModel(const char* model_path = kDefaultModelPath) {{
        model_ = tflite::FlatBufferModel::BuildFromFile(model_path);
        if (!model_) {{
            throw std::runtime_error("failed to load TFLite model");
        }}
        tflite::ops::builtin::BuiltinOpResolver resolver;
        tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
        if (!interpreter_) {{
            throw std::runtime_error("failed to build TFLite interpreter");
        }}
    }}

    // Runs one sample; assumes a single float input of shape [1, N] and a float output.
    std::vector<float> Predict(const std::vector<float>& input) {{
        // ResizeInputTensor expects a tensor index, so map input #0 to it.
        const int input_tensor = interpreter_->inputs()[0];
        interpreter_->ResizeInputTensor(input_tensor, {{1, static_cast<int>(input.size())}});
        if (interpreter_->AllocateTensors() != kTfLiteOk) {{
            throw std::runtime_error("AllocateTensors failed");
        }}

        float* input_ptr = interpreter_->typed_input_tensor<float>(0);
        std::copy(input.begin(), input.end(), input_ptr);

        if (interpreter_->Invoke() != kTfLiteOk) {{
            throw std::runtime_error("Invoke failed");
        }}

        // Read the first *output* tensor; tensor index 0 is generally an input.
        const float* output_ptr = interpreter_->typed_output_tensor<float>(0);
        const size_t output_size = interpreter_->output_tensor(0)->bytes / sizeof(float);

        return std::vector<float>(output_ptr, output_ptr + output_size);
    }}

private:
    std::unique_ptr<tflite::FlatBufferModel> model_;
    std::unique_ptr<tflite::Interpreter> interpreter_;
}};
"""

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(cpp_code)

        return cpp_code

Model Serving with NVIDIA Triton

Triton Configuration

# model_repository/fraud_detection/config.pbtxt
# Ensemble wrapper: INPUT -> preprocessing -> xgboost_model -> OUTPUT.
name: "fraud_detection"
platform: "ensemble"
max_batch_size: 64

input [
  {
    name: "INPUT"
    data_type: TYPE_FP32
    dims: [ 128 ]
  }
]

output [
  {
    name: "OUTPUT"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]

ensemble_scheduling {
  step [
    {
      model_name: "preprocessing"
      model_version: -1
      input_map {
        key: "INPUT"
        value: "INPUT"
      }
      output_map {
        key: "OUTPUT"
        value: "preprocessed"
      }
    },
    {
      model_name: "xgboost_model"
      model_version: -1
      input_map {
        key: "INPUT"
        value: "preprocessed"
      }
      output_map {
        key: "OUTPUT"
        value: "OUTPUT"
      }
    }
  ]
}

# dynamic_batching must NOT appear here: dynamic_batching, sequence_batching
# and ensemble_scheduling are alternatives (one protobuf oneof) in Triton's
# ModelConfig. Instance groups likewise belong to the models that actually
# execute. Configure both on the composing model instead, e.g. in
# model_repository/xgboost_model/config.pbtxt (its max_batch_size must be
# >= 64 to accept the ensemble's batches):
#
#   dynamic_batching {
#     preferred_batch_size: [ 16, 32, 64 ]
#     max_queue_delay_microseconds: 100
#   }
#
#   instance_group [
#     {
#       count: 2
#       kind: KIND_GPU
#       gpus: [ 0 ]
#     }
#   ]

Triton Client

import tritonclient.grpc as grpcclient
import numpy as np
from typing import Dict, List
import time

class TritonClient:
    """Synchronous gRPC client for models served by NVIDIA Triton."""

    def __init__(self, server_url: str = "localhost:8001"):
        self.client = grpcclient.InferenceServerClient(url=server_url)

    def close(self) -> None:
        """Close the underlying gRPC connection."""
        self.client.close()

    def __enter__(self) -> "TritonClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def _as_batch(data) -> np.ndarray:
        """Return ``data`` as a C-contiguous float32 array with a leading batch axis.

        ``FP32`` inputs need float32 data (NumPy defaults to float64). A model
        with ``max_batch_size > 0`` expects the batch dimension to be present,
        so a single 1-D sample is promoted to shape ``(1, n)``.
        """
        batch = np.ascontiguousarray(data, dtype=np.float32)
        if batch.ndim == 1:
            batch = batch[np.newaxis, :]
        return batch

    def _infer(self, batch: np.ndarray, model_name: str):
        """Send one inference request; return (first output value per row, latency ms)."""
        inputs = [grpcclient.InferInput("INPUT", batch.shape, "FP32")]
        inputs[0].set_data_from_numpy(batch)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT")]

        start = time.perf_counter()
        response = self.client.infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs
        )
        latency_ms = (time.perf_counter() - start) * 1000

        # The fraud_detection config declares OUTPUT dims [1] per item. Take
        # one scalar per row rather than calling float() on 1-element arrays,
        # which NumPy deprecates.
        scores = response.as_numpy("OUTPUT").reshape(batch.shape[0], -1)[:, 0]
        return scores, latency_ms

    def predict(self, input_data: np.ndarray,
                model_name: str = "fraud_detection") -> Dict:
        """Score one sample (a 1-D feature vector or a single-row batch)."""
        scores, latency_ms = self._infer(self._as_batch(input_data), model_name)
        return {
            'prediction': float(scores[0]),
            'latency_ms': latency_ms,
            'model_name': model_name
        }

    def batch_predict(self, batch_data: np.ndarray,
                      model_name: str = "fraud_detection") -> List[Dict]:
        """Score a batch in one request; latency is split evenly across items."""
        batch = self._as_batch(batch_data)
        batch_size = batch.shape[0]
        if batch_size == 0:
            return []

        scores, total_latency_ms = self._infer(batch, model_name)
        per_item_latency = total_latency_ms / batch_size

        return [
            {
                'prediction': float(score),
                'latency_ms': per_item_latency,
                'batch_index': i
            }
            for i, score in enumerate(scores)
        ]

    def get_model_metadata(self, model_name: str) -> Dict:
        """Return model metadata from Triton as plain Python types."""
        metadata = self.client.get_model_metadata(model_name=model_name)

        def describe(tensors):
            # Protobuf repeated fields become lists so the result is JSON-friendly.
            return [
                {'name': t.name, 'datatype': t.datatype, 'shape': list(t.shape)}
                for t in tensors
            ]

        return {
            'name': metadata.name,
            'version': metadata.version,
            'platform': metadata.platform,
            'inputs': describe(metadata.inputs),
            'outputs': describe(metadata.outputs)
        }

ℹ️

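NVIDIA Triton supports dynamic batching, model ensembles, and multi-GPU inference. Use it for high-throughput serving with tight (single-digit millisecond) latency budgets. Note that dynamic batching deliberately adds a small, configurable queueing delay (`max_queue_delay_microseconds`) in exchange for throughput, and ensemble settings such as dynamic batching belong on the composing models.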

Summary

Model serving patterns depend on your requirements:

  1. Batch Inference: PySpark for large-scale offline processing
  2. Real-Time REST: FastAPI/Flask for synchronous web APIs
  3. Real-Time gRPC: Triton for high-throughput, low-latency serving
  4. A/B Testing: Statistical frameworks with traffic splitting
  5. Shadow Mode: Validate new models without affecting users
  6. Edge Deployment: TFLite/ONNX for on-device inference

Choose the pattern that matches your latency, throughput, and cost requirements.

Advertisement