Model Registry & Versioning
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Model Registry Fundamentals
A model registry provides centralized model management with versioning, lineage tracking, and lifecycle management.
ℹ️
Netflix's model registry tracks over 1000 production models with full lineage from data to deployment.
MLflow Model Registry
# mlflow_registry.py
from mlflow.tracking import MlflowClient
from mlflow.entities import ModelVersion, Run
from mlflow.store.model_registry import models_uri_to_latest
import mlflow.pytorch
import mlflow.sklearn
from typing import Optional, Dict, List
from dataclasses import dataclass
from enum import Enum
class ModelStage(Enum):
    """Stage names used when transitioning a model version.

    Each member's ``value`` is the string handed to the registry client as
    the ``stage`` argument.
    """

    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"
@dataclass
class ModelMetadata:
    """Plain record of descriptive fields for one model version.

    Every field is required (no defaults) and is positional in the order
    declared below.
    """

    # Identity of the version.
    name: str
    version: str
    stage: str

    # Free-form annotations.
    description: str
    tags: Dict[str, str]

    # Provenance and timestamps, all carried as strings.
    run_id: str
    created_at: str
    updated_at: str
class ModelRegistryManager:
    """Convenience wrapper around ``MlflowClient`` model-registry calls."""

    def __init__(self, tracking_uri: str = "http://localhost:5000"):
        """Create the underlying client for the given ``tracking_uri``."""
        self.client = MlflowClient(tracking_uri=tracking_uri)

    def register_model(self, run_id: str, model_name: str, description: str = "") -> ModelVersion:
        """Register the ``model`` artifact of ``run_id`` under ``model_name``.

        After registration the new version's description is set to
        ``description``. Returns the object produced by
        ``mlflow.register_model``.
        """
        # NOTE(review): this is a module-level mlflow call rather than a
        # method on self.client -- confirm it targets the same tracking
        # server that self.client was configured with.
        registered = mlflow.register_model(f"runs:/{run_id}/model", model_name)
        self.client.update_model_version(
            name=model_name,
            version=registered.version,
            description=description,
        )
        return registered

    def transition_model_stage(
        self,
        model_name: str,
        version: str,
        stage: ModelStage,
        archive_existing: bool = True,
    ):
        """Move ``version`` of ``model_name`` to ``stage``.

        ``archive_existing`` is forwarded as ``archive_existing_versions``.
        """
        request = {
            "name": model_name,
            "version": version,
            "stage": stage.value,
            "archive_existing_versions": archive_existing,
        }
        self.client.transition_model_version_stage(**request)

    def add_model_tag(self, model_name: str, version: str, key: str, value: str):
        """Set tag ``key`` = ``value`` on one model version."""
        tag = {"key": key, "value": value}
        self.client.set_model_version_tag(name=model_name, version=version, **tag)

    def get_model_version(self, model_name: str, version: str) -> ModelVersion:
        """Fetch a single model version from the client."""
        found = self.client.get_model_version(model_name, version)
        return found

    def get_latest_versions(self, model_name: str, stages: Optional[List[str]] = None) -> List[ModelVersion]:
        """Return the client's latest versions, optionally limited to ``stages``."""
        latest = self.client.get_latest_versions(model_name, stages)
        return latest

    def get_model_version_by_stage(self, model_name: str, stage: ModelStage) -> Optional[ModelVersion]:
        """Return the first latest version in ``stage``, or ``None`` if there is none."""
        candidates = self.get_latest_versions(model_name, [stage.value])
        if not candidates:
            return None
        return candidates[0]

    def compare_model_versions(self, model_name: str, version_a: str, version_b: str) -> Dict:
        """Collect stage, run metrics and run params for two versions.

        Returns a dict with keys ``"version_a"`` and ``"version_b"``, each
        mapping to ``{"stage", "metrics", "params"}``.
        """
        # Both model versions are fetched first, then both runs.
        model_versions = [
            self.get_model_version(model_name, requested)
            for requested in (version_a, version_b)
        ]
        runs = [self.client.get_run(mv.run_id) for mv in model_versions]

        comparison = {}
        for label, mv, run in zip(("version_a", "version_b"), model_versions, runs):
            comparison[label] = {
                "stage": mv.current_stage,
                "metrics": run.data.metrics,
                "params": run.data.params,
            }
        return comparison

    def delete_model_version(self, model_name: str, version: str):
        """Delete one model version through the client."""
        self.client.delete_model_version(model_name, version)
# Usage: register a run's model, tag the new version, then move it to Staging.
registry = ModelRegistryManager()
run_id = "abc123def456"

model_version = registry.register_model(
    run_id,
    "churn-predictor",
    "Random Forest model for customer churn prediction v2.1",
)

# Tag the version that was just registered.
registry.add_model_tag(
    model_name="churn-predictor",
    version=model_version.version,
    key="team",
    value="ml-ops",
)
registry.add_model_tag(
    model_name="churn-predictor",
    version=model_version.version,
    key="dataset_version",
    value="2024-01-15",
)

registry.transition_model_stage(
    model_name="churn-predictor",
    version=model_version.version,
    stage=ModelStage.STAGING,
)
Custom Model Registry
# custom_registry.py
import json
import pickle
import hashlib
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
import shutil
class ModelStage(Enum):
    """Lifecycle stages a model version can be in within the custom registry.

    Each member's ``value`` is the lowercase string stored in the registry's
    JSON files.
    """

    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"
@dataclass
class ModelVersionInfo:
    """Full description of one registered model version.

    Every field is required (no defaults) and is positional in the order
    declared below.
    """

    # Identity and lifecycle.
    version: str
    model_name: str
    stage: ModelStage

    # Descriptive information supplied at registration time.
    description: str
    tags: Dict[str, str]
    metrics: Dict[str, float]
    params: Dict[str, Any]

    # Provenance.
    run_id: str
    artifact_path: str

    # Audit fields, carried as strings.
    created_at: str
    updated_at: str
    created_by: str
class ModelRegistry:
    """File-system backed model registry.

    Layout under ``registry_path``::

        index.json                             summary of every model/version
        <model_name>/v<version>/model.pkl      pickled model artifact
        <model_name>/v<version>/metadata.json  full version metadata

    ``index.json`` has the shape
    ``{"models": {name: {"versions": {version: {"stage", "hash", "created_at"}}}}}``.

    Artifacts are stored with ``pickle``; only use a registry directory whose
    contents you trust, because unpickling can execute arbitrary code.
    """

    def __init__(self, registry_path: str = "./model_registry"):
        """Create the registry directory if needed and load its index."""
        self.registry_path = Path(registry_path)
        self.registry_path.mkdir(parents=True, exist_ok=True)
        self._load_index()

    def _load_index(self):
        """Populate ``self.index`` from ``index.json``, or start an empty one."""
        index_file = self.registry_path / "index.json"
        if index_file.exists():
            with open(index_file, encoding="utf-8") as f:
                self.index = json.load(f)
        else:
            self.index = {"models": {}}

    def _save_index(self):
        """Write ``self.index`` to ``index.json``."""
        with open(self.registry_path / "index.json", "w", encoding="utf-8") as f:
            json.dump(self.index, f, indent=2)

    def register_model(
        self,
        model: Any,
        model_name: str,
        description: str = "",
        metrics: Optional[Dict[str, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, str]] = None,
        run_id: Optional[str] = None,
        created_by: str = "system"
    ) -> ModelVersionInfo:
        """Pickle ``model`` as the next version of ``model_name``.

        The new version starts in ``ModelStage.DEVELOPMENT``. The artifact,
        its ``metadata.json`` and the registry index are all written before
        returning.

        Returns:
            The ``ModelVersionInfo`` describing the new version.
        """
        version = self._get_next_version(model_name)
        model_dir = self.registry_path / model_name / f"v{version}"
        model_dir.mkdir(parents=True, exist_ok=True)

        artifact_path = model_dir / "model.pkl"
        with open(artifact_path, "wb") as f:
            pickle.dump(model, f)
        model_hash = self._compute_hash(artifact_path)

        # A single timestamp so created_at and updated_at agree on creation.
        timestamp = datetime.now().isoformat()
        version_info = ModelVersionInfo(
            version=str(version),
            model_name=model_name,
            stage=ModelStage.DEVELOPMENT,
            description=description,
            tags=tags or {},
            metrics=metrics or {},
            params=params or {},
            run_id=run_id or "",
            artifact_path=str(artifact_path),
            created_at=timestamp,
            updated_at=timestamp,
            created_by=created_by,
        )

        # asdict() leaves the Enum member in place, which json cannot encode;
        # store its string value, the same form promote_model() writes.
        metadata = asdict(version_info)
        metadata["stage"] = version_info.stage.value
        with open(model_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Touch the index only after the metadata file was written, so a
        # failure above does not leave a dangling in-memory index entry.
        model_entry = self.index["models"].setdefault(model_name, {"versions": {}})
        model_entry["versions"][str(version)] = {
            "stage": version_info.stage.value,
            "hash": model_hash,
            "created_at": version_info.created_at,
        }
        self._save_index()
        return version_info

    def promote_model(self, model_name: str, version: str, target_stage: ModelStage):
        """Set the stage of an existing version in both metadata and index.

        Raises:
            FileNotFoundError: if the version's ``metadata.json`` is missing.
            KeyError: if the version is missing from the index.
        """
        # Index keys are strings; accept an int version without a KeyError
        # after the metadata file has already been rewritten.
        version = str(version)
        metadata_file = self.registry_path / model_name / f"v{version}" / "metadata.json"
        with open(metadata_file, encoding="utf-8") as f:
            metadata = json.load(f)

        metadata["stage"] = target_stage.value
        metadata["updated_at"] = datetime.now().isoformat()
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        self.index["models"][model_name]["versions"][version]["stage"] = target_stage.value
        self._save_index()

    def load_model(self, model_name: str, version: Optional[str] = None, stage: Optional[ModelStage] = None):
        """Unpickle and return a stored model.

        If ``stage`` is given it takes precedence over ``version`` and the
        highest-numbered version currently in that stage is loaded.

        Raises:
            ValueError: if neither a version nor a matching stage is found.
        """
        if stage is not None:
            version = self._get_version_by_stage(model_name, stage)
        if version is None:
            raise ValueError(f"No model found for {model_name}")

        model_path = self.registry_path / model_name / f"v{version}" / "model.pkl"
        # SECURITY: pickle.load executes code embedded in the file. Only load
        # artifacts from a registry directory you control.
        with open(model_path, "rb") as f:
            return pickle.load(f)

    def list_models(self) -> List[str]:
        """Return the names of all models present in the index."""
        return list(self.index["models"])

    def list_versions(self, model_name: str) -> List[Dict]:
        """Return one dict per version: the index entry plus a ``"version"`` key."""
        if model_name not in self.index["models"]:
            return []
        versions = self.index["models"][model_name]["versions"]
        return [{"version": v, **info} for v, info in versions.items()]

    def get_latest_production_model(self, model_name: str) -> Optional[Dict]:
        """Return the highest-numbered production version entry, or ``None``."""
        prod_versions = [
            v for v in self.list_versions(model_name) if v["stage"] == "production"
        ]
        if not prod_versions:
            return None
        return max(prod_versions, key=lambda x: int(x["version"]))

    def _get_next_version(self, model_name: str) -> int:
        """Return 1 for a new model, otherwise the highest version plus one."""
        if model_name not in self.index["models"]:
            return 1
        versions = self.index["models"][model_name]["versions"]
        return max((int(v) for v in versions), default=0) + 1

    def _get_version_by_stage(self, model_name: str, stage: ModelStage) -> Optional[str]:
        """Return the highest-numbered version in ``stage``, or ``None``.

        Choosing the highest number (not the first index entry) matches
        ``get_latest_production_model`` and avoids returning a stale version
        when several versions share a stage.
        """
        matching = [
            v["version"]
            for v in self.list_versions(model_name)
            if v["stage"] == stage.value
        ]
        if not matching:
            return None
        return max(matching, key=int)

    def _compute_hash(self, file_path: Path) -> str:
        """Return the first 16 hex characters of the file's SHA-256 digest."""
        hasher = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read in fixed-size chunks so large artifacts are not loaded whole.
            for chunk in iter(lambda: f.read(8192), b""):
                hasher.update(chunk)
        return hasher.hexdigest()[:16]
# Usage: register a freshly trained model, promote it, then load it by stage.
registry = ModelRegistry("./my_model_registry")
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100)
model.fit([[1, 2], [3, 4]], [0, 1])
version_info = registry.register_model(
    model=model,
    model_name="fraud-detector",
    description="Random Forest fraud detection model",
    metrics={"accuracy": 0.95, "f1": 0.93},
    params={"n_estimators": 100},
    tags={"team": "security", "priority": "high"}
)
# Promote the version that was just registered rather than a hard-coded "1":
# the registry directory persists, so later runs get a higher version number.
registry.promote_model("fraud-detector", version_info.version, ModelStage.PRODUCTION)
loaded_model = registry.load_model("fraud-detector", stage=ModelStage.PRODUCTION)
Follow-Up Questions
- How do you handle model rollback in production?
- What metadata should be tracked for regulatory compliance?
- How would you implement A/B testing between model versions?
- What are the implications of model registry design on team collaboration?