🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

Model Serving Patterns

MLOps · Model Deployment · ⭐ Premium

Advertisement

Model Serving Patterns

Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe

Serving Architecture

Model serving must balance latency, throughput, cost, and reliability requirements.

ℹ️

Google's TFServing handles 100+ billion predictions per day across all serving patterns.

FastAPI Model Server

# model_server.py
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import numpy as np
import pickle
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ML Model Server",
    description="Production model serving API",
    version="1.0.0",
)

# CORS: browser clients from any origin may call the API, but credentialed
# requests (cookies, HTTP auth) are NOT enabled. A wildcard origin combined with
# allow_credentials=True is disallowed by the CORS spec and by FastAPI's documented
# CORSMiddleware rules, and none of the endpoints defined in this file read cookies
# or auth headers. If credentials are ever required, replace "*" with an explicit
# list of trusted origins before turning allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

class PredictionRequest(BaseModel):
    """Request body for POST /predict: a single feature vector to score."""
    # One row of features; the server reshapes it to (1, -1) before scoring.
    # NOTE: the schema itself accepts an empty list.
    features: List[float]
    # Opaque client-supplied id, echoed back unchanged in the response.
    request_id: Optional[str] = None

class PredictionResponse(BaseModel):
    """Response body for POST /predict."""
    # Model output for the row, cast to float.
    prediction: float
    # Highest predict_proba value for the row; None when the model has no predict_proba.
    probability: Optional[float] = None
    # ModelManager.model_version at the time of the request.
    model_version: str
    # Time spent inside ModelManager.predict, in milliseconds (excludes HTTP overhead).
    latency_ms: float
    # Copied from the request's request_id.
    request_id: Optional[str] = None

class BatchPredictionRequest(BaseModel):
    """Request body for POST /predict/batch: several feature rows scored in one model call."""
    # One inner list per row. The schema does not check that rows share a length
    # or that the outer list is non-empty.
    instances: List[List[float]]
    # Opaque client-supplied id, echoed back in the batch response.
    request_id: Optional[str] = None

class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # Liveness label reported by the health endpoint.
    status: str
    # True when ModelManager currently holds a model.
    model_loaded: bool
    # Seconds since the model was last (re)loaded -- not process uptime.
    uptime_seconds: float
    # Running count of rows scored by single and batch prediction.
    total_predictions: int

class ModelManager:
    """Hold the served model and run single-row and batch predictions against it.

    Attributes:
        model: the unpickled estimator, or ``None`` until ``load_model`` succeeds.
            It must provide ``predict``; ``predict_proba`` is used when present.
        model_version: version string reported to clients. ``load_model`` does not
            change it, so it stays "1.0.0" across reloads.
        load_time: ``time.time()`` of the most recent successful load, else ``None``.
        total_predictions: running count of rows scored (single + batch).
        executor: a 4-worker thread pool; not used by this class itself.
    """

    def __init__(self):
        self.model = None
        self.model_version = "1.0.0"
        self.load_time = None
        self.total_predictions = 0
        self.executor = ThreadPoolExecutor(max_workers=4)

    def load_model(self, model_path: str):
        """Unpickle the model stored at ``model_path`` and make it the served model.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
        Only load artifacts from a trusted location.

        The served model is replaced only after unpickling succeeds, so a failed
        load leaves the previous model (and ``load_time``) untouched.

        Raises:
            OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
            Exception: whatever ``pickle.load`` raises for a corrupt/incompatible file.
        """
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        self.model = model
        self.load_time = time.time()
        logger.info("Model loaded from %s", model_path)

    def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """Score one feature vector.

        Args:
            features: feature values; any shape is flattened into a single row
                via ``reshape(1, -1)``.

        Returns:
            ``{"prediction": float, "probability": float | None, "latency_ms": float}``
            where ``probability`` is the highest ``predict_proba`` value (``None`` if
            the model has no ``predict_proba``) and ``latency_ms`` covers only the
            model calls.
        """
        # perf_counter is monotonic and high-resolution; wall-clock time.time() can
        # jump (NTP adjustments) and yield wrong or negative latencies.
        start = time.perf_counter()
        row = features.reshape(1, -1)

        prediction = self.model.predict(row)[0]
        probability = None
        if hasattr(self.model, "predict_proba"):
            probability = float(self.model.predict_proba(row).max())

        latency_ms = (time.perf_counter() - start) * 1000
        self.total_predictions += 1

        return {
            "prediction": float(prediction),
            "probability": probability,
            "latency_ms": latency_ms,
        }

    def predict_batch(self, instances: List[np.ndarray]) -> List[Dict[str, Any]]:
        """Score several rows with a single model call.

        Args:
            instances: feature rows; they are stacked with ``np.array`` and so
                should all have the same length.

        Returns:
            One ``{"prediction": float, "probability": float | None}`` dict per row,
            in input order. An empty ``instances`` returns ``[]`` without calling the
            model (stacking nothing would hand the model a 1-D empty array).
        """
        if len(instances) == 0:
            return []

        batch = np.array(instances)
        predictions = self.model.predict(batch)
        probabilities = (
            self.model.predict_proba(batch) if hasattr(self.model, "predict_proba") else None
        )

        results = [
            {
                "prediction": float(pred),
                "probability": float(probabilities[i].max()) if probabilities is not None else None,
            }
            for i, pred in enumerate(predictions)
        ]

        self.total_predictions += len(instances)
        return results

# Module-level instance shared by every endpoint below; startup_event loads the
# model into it before requests are served.
model_manager = ModelManager()

@app.on_event("startup")
async def startup_event():
    """Load the default pickled model once, before the app starts handling requests."""
    default_model_file = "model.pkl"
    model_manager.load_model(default_model_file)

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether a model is loaded, plus basic serving statistics.

    ``status`` is "healthy" only when a model is loaded, otherwise "unhealthy".
    ``uptime_seconds`` is measured from the most recent successful model load
    (every load resets it) and is 0.0 if no model has been loaded yet, instead of
    failing with a TypeError on ``load_time`` being ``None``.
    """
    loaded = model_manager.model is not None
    load_time = model_manager.load_time
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        uptime_seconds=0.0 if load_time is None else time.time() - load_time,
        total_predictions=model_manager.total_predictions,
    )

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Score one feature vector with the currently loaded model.

    Raises:
        HTTPException(503): no model is loaded.
        HTTPException(422): ``features`` is empty, or its length differs from the
            model's ``n_features_in_`` (checked only when the model exposes it).
    """
    model = model_manager.model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.features:
        raise HTTPException(status_code=422, detail="features must not be empty")

    # scikit-learn-style estimators record their training width in n_features_in_.
    # Checking it here reports a shape mismatch as a client error rather than a 500
    # raised from inside the model.
    expected = getattr(model, "n_features_in_", None)
    if expected is not None and len(request.features) != expected:
        raise HTTPException(
            status_code=422,
            detail=f"expected {expected} features, got {len(request.features)}",
        )

    features = np.array(request.features)
    # NOTE(review): inference runs synchronously inside this async handler, so a slow
    # model blocks the event loop (including /health) until it returns.
    result = model_manager.predict(features)

    return PredictionResponse(
        prediction=result["prediction"],
        probability=result["probability"],
        model_version=model_manager.model_version,
        latency_ms=result["latency_ms"],
        request_id=request.request_id,
    )

@app.post("/predict/batch")
async def predict_batch(request: BatchPredictionRequest):
    """Score several feature rows in a single model call.

    Returns a dict with ``predictions`` (one ``{"prediction", "probability"}`` entry per
    row, in input order), ``model_version``, ``batch_size`` and the echoed
    ``request_id``. An empty ``instances`` list yields empty ``predictions`` without
    invoking the model.

    Raises:
        HTTPException(503): no model is loaded.
        HTTPException(422): rows have differing lengths, rows are empty, or the row
            width differs from the model's ``n_features_in_`` (when exposed).
    """
    model = model_manager.model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    results = []
    if request.instances:
        # Rows are stacked into one 2-D array, so they must all share a width;
        # ragged input would otherwise fail deep inside numpy/the model as a 500.
        widths = {len(inst) for inst in request.instances}
        if len(widths) != 1:
            raise HTTPException(
                status_code=422,
                detail="all instances must have the same number of features",
            )
        width = widths.pop()
        if width == 0:
            raise HTTPException(status_code=422, detail="instances must not be empty rows")
        expected = getattr(model, "n_features_in_", None)
        if expected is not None and width != expected:
            raise HTTPException(
                status_code=422,
                detail=f"expected {expected} features per instance, got {width}",
            )

        instances = [np.array(inst) for inst in request.instances]
        results = model_manager.predict_batch(instances)

    return {
        "predictions": results,
        "model_version": model_manager.model_version,
        "batch_size": len(results),
        "request_id": request.request_id,
    }

@app.post("/model/reload")
async def reload_model(model_path: str = "model.pkl"):
    """Replace the served model with the pickle stored at ``model_path``.

    SECURITY: ``model_path`` comes straight from the query string and the file is
    deserialized with ``pickle``, which can execute arbitrary code. Anyone who can
    reach this endpoint can make the server unpickle any file it can read. Put it
    behind authentication / network policy and, ideally, restrict paths to a
    trusted model directory before exposing it.

    ``ModelManager.load_model`` swaps the model only after unpickling succeeds, so a
    failed reload keeps the previous model serving. The returned ``version`` is
    ``model_manager.model_version``, which a reload does not change.

    Raises:
        HTTPException(404): ``model_path`` does not exist.
    """
    try:
        model_manager.load_model(model_path)
    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"Model file not found: {model_path}"
        ) from err
    return {"status": "reloaded", "version": model_manager.model_version}

TensorFlow Serving

# tf_serving_client.py
import requests
import numpy as np
import json
from typing import Dict, List, Any
from dataclasses import dataclass

@dataclass
class TFServerConfig:
    """Where to reach a TensorFlow Serving REST endpoint and which model version to target."""

    host: str  # hostname or IP of the serving process
    port: int  # REST port (the usage example below uses 8501)
    model_name: str  # the {model_name} segment of /v1/models/{model_name}
    model_version: int  # the version placed in /versions/{model_version} request URLs

class TFServingClient:
    """Thin client for the TensorFlow Serving REST API, pinned to one model version.

    Every HTTP call passes an explicit ``timeout``: ``requests`` waits forever by
    default, so an unresponsive server would otherwise hang the caller indefinitely.
    """

    def __init__(self, config: TFServerConfig, timeout: float = 10.0):
        """Create a client for ``config``.

        Args:
            config: server host/port plus the model name and version to target.
            timeout: seconds to wait for the connection and for each read, forwarded
                to every ``requests`` call. Defaults to 10 seconds.
        """
        self.config = config
        self.timeout = timeout
        self.base_url = f"http://{config.host}:{config.port}/v1/models/{config.model_name}"

    def _version_url(self) -> str:
        """Return ``{base_url}/versions/{model_version}`` for the configured version."""
        return f"{self.base_url}/versions/{self.config.model_version}"

    def _post_predict(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the version's ``:predict`` endpoint and return the decoded JSON.

        Passing ``json=`` makes ``requests`` serialize the body and send
        ``Content-Type: application/json``.

        Raises:
            requests.HTTPError: the server answered with a 4xx/5xx status.
            requests.RequestException: connection failure or timeout.
        """
        response = requests.post(
            f"{self._version_url()}:predict",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def predict(self, input_data: np.ndarray) -> Dict[str, Any]:
        """Predict with the model's default signature.

        ``input_data`` is sent in row format: each element along its first axis
        becomes one entry of ``"instances"``.
        """
        return self._post_predict({"instances": input_data.tolist()})

    def predict_with_signature(self, input_data: np.ndarray, signature_name: str) -> Dict[str, Any]:
        """Predict with the named serving signature instead of the default one."""
        return self._post_predict({
            "instances": input_data.tolist(),
            "signature_name": signature_name,
        })

    def get_model_metadata(self) -> Dict[str, Any]:
        """Return the metadata (e.g. signature definitions) of the configured version.

        The URL is pinned to ``model_version`` so the metadata describes the same
        version that ``predict`` calls, rather than whatever version is latest.
        """
        response = requests.get(f"{self._version_url()}/metadata", timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def health_check(self) -> bool:
        """Return True only if the configured model version reports state "AVAILABLE".

        Queries the model-status endpoint. Network errors, timeouts, non-200
        responses, undecodable bodies, and any state other than "AVAILABLE"
        (e.g. while still loading) all report False instead of raising.
        """
        try:
            response = requests.get(self._version_url(), timeout=self.timeout)
            if response.status_code != 200:
                return False
            body = response.json()
        except (requests.RequestException, ValueError):
            # ValueError covers a response body that is not valid JSON.
            return False

        statuses = body.get("model_version_status", []) if isinstance(body, dict) else []
        return any(
            isinstance(status, dict) and status.get("state") == "AVAILABLE"
            for status in statuses
        )


# Usage
# Usage
# NOTE: running this module sends a real HTTP predict request to the server below.
config = TFServerConfig(
    model_name="image_classifier",
    model_version=1,
    host="localhost",
    port=8501,
)
client = TFServingClient(config)

# One random float32 sample; presumably (batch, height, width, channels) for an
# image model -- confirm against the served model's signature.
input_data = np.random.randn(1, 224, 224, 3).astype(np.float32)
predictions = client.predict(input_data)

Load Balancer Configuration

# nginx_load_balancer.conf
# Pool of model-server replicas. least_conn sends each request to the server with the
# fewest active connections, taking weights into account (ml-server-3 gets a smaller share).
upstream ml_backend {
    least_conn;
    server ml-server-1:8000 weight=3;
    server ml-server-2:8000 weight=3;
    server ml-server-3:8000 weight=2;

    # Maximum idle upstream connections cached per worker process. They are only reused
    # when the proxying location sets `proxy_http_version 1.1` and clears the
    # Connection header.
    keepalive 32;
}

# Public entry point: routes model traffic and health checks to the ml_backend pool.
server {
    listen 80;
    server_name ml-api.example.com;

    location /v1/models/ {
        proxy_pass http://ml_backend;
        # Without these two lines, the upstream `keepalive` pool is never used:
        # nginx speaks HTTP/1.0 to upstreams by default and sends "Connection: close",
        # so every request would open a fresh TCP connection to the backend.
        proxy_http_version 1.1;
        proxy_set_header Connection "";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

        proxy_connect_timeout 5s;
        proxy_read_timeout 30s;
        proxy_send_timeout 30s;

        # NOTE(review): once a request has been sent upstream, nginx does not pass
        # non-idempotent methods (POST, e.g. :predict) to the next server unless
        # `non_idempotent` is listed here. As written, a POST that times out or gets a
        # 502/503 is not retried elsewhere. Add `non_idempotent` only if re-running an
        # inference on another replica is acceptable.
        proxy_next_upstream error timeout http_502 http_503;
        proxy_next_upstream_tries 3;
    }

    location /health {
        proxy_pass http://ml_backend;
        # Reuse the upstream keepalive pool here as well.
        proxy_http_version 1.1;
        proxy_set_header Connection "";
    }

    location /metrics {
        # NOTE(review): proxy_pass has no URI part, so the request path (/metrics) is
        # forwarded unchanged. This serves whatever prometheus:9090 exposes at /metrics
        # (presumably Prometheus's own metrics, not the model servers') on the public
        # listener. Confirm the intent, and restrict access if it should not be public.
        proxy_pass http://prometheus:9090;
    }
}

Follow-Up Questions

  1. How would you implement model canary deployments?
  2. What are the trade-offs between batching and real-time inference?
  3. How do you handle model fallback when the primary model fails?
  4. What caching strategies work best for ML predictions?

Advertisement