
Experiment Tracking: MLflow, W&B, Neptune, DVC


Interview Question (Medium), asked at: Google, Microsoft, Amazon, Netflix

"How do you track and manage ML experiments at scale? Compare MLflow, W&B, and Neptune. How do you ensure reproducibility across teams and handle model versioning in production?"

Experiment Tracking Architecture

Experiment tracking captures all metadata, code, data, and artifacts associated with ML experiments. At scale, this requires a centralized system with versioning, search, and collaboration features.

Core Components

Architecture Diagram
┌─────────────────────────────────────────────────────────────────┐
│                Experiment Tracking Architecture                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────┐       │
│  │   ML Client  │    │   ML Client  │    │   ML Client  │       │
│  │   (Python)   │    │   (R/Java)   │    │  (Notebook)  │       │
│  └──────┬───────┘    └──────┬───────┘    └──────┬───────┘       │
│         │                   │                   │               │
│         └───────────────────┼───────────────────┘               │
│                             ▼                                   │
│  ┌────────────────────────────────────────────────────────┐     │
│  │                 Tracking Server / API                  │     │
│  │  • Experiment Metadata    • Model Registry             │     │
│  │  • Run Metadata           • Artifact Storage           │     │
│  │  • Metrics & Params       • Lineage Tracking           │     │
│  └────────────────────────────────────────────────────────┘     │
│         │                   │                   │               │
│         ▼                   ▼                   ▼               │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────┐       │
│  │   Metadata   │    │   Artifact   │    │    Metric    │       │
│  │   Database   │    │    Store     │    │    Store     │       │
│  │  (Postgres)  │    │   (S3/GCS)   │    │ (Timescale)  │       │
│  └──────────────┘    └──────────────┘    └──────────────┘       │
└─────────────────────────────────────────────────────────────────┘
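
To make the three-store split in the diagram concrete, here is a minimal in-memory sketch. The class and method names (`InMemoryTrackingServer`, `create_run`, `log_metric`, `log_artifact`, `best_run`) are hypothetical and do not correspond to any library's API; plain dictionaries stand in for Postgres, S3/GCS, and the time-series store.

```python
import hashlib
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Tuple


class InMemoryTrackingServer:
    """Toy tracking server (hypothetical): dicts stand in for the backing stores."""

    def __init__(self) -> None:
        self.metadata: Dict[str, Dict[str, Any]] = {}   # stand-in for Postgres
        self.artifacts: Dict[str, bytes] = {}           # stand-in for S3/GCS
        # (run_id, metric name) -> list of (step, value); stand-in for a metric store
        self.metrics: Dict[Tuple[str, str], List[Tuple[int, float]]] = defaultdict(list)

    def create_run(self, experiment: str, params: Dict[str, Any]) -> str:
        run_id = uuid.uuid4().hex
        self.metadata[run_id] = {
            "experiment": experiment,
            "params": dict(params),
            "start_time": time.time(),
            "artifact_keys": [],
        }
        return run_id

    def log_metric(self, run_id: str, key: str, value: float, step: int) -> None:
        self.metrics[(run_id, key)].append((step, value))

    def log_artifact(self, run_id: str, name: str, payload: bytes) -> str:
        # Store the blob, and record its key and content hash in the metadata store
        digest = hashlib.sha256(payload).hexdigest()
        key = f"{run_id}/{name}"
        self.artifacts[key] = payload
        self.metadata[run_id]["artifact_keys"].append({"key": key, "sha256": digest})
        return digest

    def best_run(self, experiment: str, metric: str) -> str:
        """Return the run whose last logged value of `metric` is highest."""
        candidates = {
            run_id: series[-1][1]
            for (run_id, key), series in self.metrics.items()
            if key == metric and self.metadata[run_id]["experiment"] == experiment
        }
        return max(candidates, key=candidates.get)
```

Keeping large blobs out of the metadata store and referencing them by key and hash is the main design point: queries over runs stay cheap while artifacts can live on object storage.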

MLflow Deep Dive

MLflow Tracking Server Setup

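# mlflow-server-config.yaml
# Illustrative deployment config (for example, values consumed by a wrapper
# script or Helm chart). MLflow itself does not read this file: the server is
# configured through `mlflow server` flags such as --backend-store-uri,
# --default-artifact-root, --host, and --port. Authentication such as OAuth2
# is usually handled by a reverse proxy or plugin in front of the server.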
server:
  host: 0.0.0.0
  port: 5000
  
  # Backend store for metadata
  backend_store:
    type: sqlalchemy
    uri: postgresql://user:password@postgres:5432/mlflow
    
  # Artifact store for models, data, plots
  artifact_store:
    type: s3
    s3:
      bucket: mlflow-artifacts
      prefix: experiments/
      
  # Authentication
  auth:
    type: oauth2
    oauth2:
      provider: github
      client_id: ${GITHUB_CLIENT_ID}
      client_secret: ${GITHUB_CLIENT_SECRET}
      
  # Experiment organization
  experiments:
    default_location: /experiments
    
  # Model registry
  model_registry:
    type: sqlalchemy
    uri: postgresql://user:password@postgres:5432/mlflow_registry

MLflow Experiment Tracking Implementation

import mlflow
import mlflow.sklearn
import mlflow.xgboost
import mlflow.lightgbm
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
import pandas as pd
import numpy as np
import json
import hashlib
from datetime import datetime
from typing import Dict, Any, List, Optional
from pathlib import Path
import joblib
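import shap
import mlflow.pyfunc


class PythonModelWrapper(mlflow.pyfunc.PythonModel):
    """Minimal pyfunc wrapper for models that only expose predict()."""

    def __init__(self, model):
        self.model = model

    def predict(self, context, model_input, params=None):
        return self.model.predict(model_input)
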
class MLflowExperimentTracker:
    def __init__(self, experiment_name: str, tracking_uri: str = None):
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)
        
        mlflow.set_experiment(experiment_name)
        self.client = MlflowClient()
        self.experiment_name = experiment_name
    
    def _compute_data_fingerprint(self, df: pd.DataFrame) -> str:
        """Compute deterministic fingerprint for dataset."""
        return hashlib.sha256(
            pd.util.hash_pandas_object(df).values.tobytes()
        ).hexdigest()[:16]
    
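    def _log_code_version(self):
        """Log git commit, branch, and any uncommitted changes."""
        import git
        repo = git.Repo(search_parent_directories=True)
        
        mlflow.log_param("git_commit", repo.head.commit.hexsha)
        # active_branch raises TypeError on a detached HEAD (common in CI)
        branch = "DETACHED" if repo.head.is_detached else repo.active_branch.name
        mlflow.log_param("git_branch", branch)
        mlflow.log_param("git_dirty", repo.is_dirty())
        
        if repo.is_dirty():
            # `git diff HEAD` gives the patch text for staged and unstaged changes
            mlflow.log_text(
                repo.git.diff("HEAD"), 
                "artifacts/code_changes.diff"
            )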
    
    def start_run(self, run_name: str = None, 
                  tags: Dict[str, str] = None):
        """Start a tracked MLflow run."""
        
        if run_name is None:
            run_name = f"run_{datetime.now():%Y%m%d_%H%M%S}"
        
        mlflow.start_run(run_name=run_name, tags=tags or {})
        
        # Auto-log code version
        self._log_code_version()
        
        return mlflow.active_run()
    
    def log_parameters(self, params: Dict[str, Any]):
        """Log hyperparameters with type preservation."""
        for key, value in params.items():
            if isinstance(value, dict):
                mlflow.log_param(key, json.dumps(value))
            elif isinstance(value, list):
                mlflow.log_param(key, json.dumps(value))
            else:
                mlflow.log_param(key, value)
    
    def log_metrics(self, metrics: Dict[str, float], step: int = None):
        """Log metrics with optional step for time series."""
        for key, value in metrics.items():
            mlflow.log_metric(key, value, step=step)
    
    def log_dataset(self, df: pd.DataFrame, name: str, 
                    tags: Dict[str, str] = None):
        """Log dataset with versioning."""
        fingerprint = self._compute_data_fingerprint(df)
        
        mlflow.log_param(f"dataset_{name}_fingerprint", fingerprint)
        mlflow.log_param(f"dataset_{name}_rows", len(df))
        mlflow.log_param(f"dataset_{name}_columns", len(df.columns))
        
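        # Save dataset and schema via a temporary directory, so no local
        # "datasets/" folder has to exist before logging
        import tempfile
        with tempfile.TemporaryDirectory() as tmp_dir:
            data_path = Path(tmp_dir) / f"{name}.parquet"
            df.to_parquet(data_path, index=False)
            mlflow.log_artifact(str(data_path), artifact_path="datasets")
            
            schema = {
                'columns': list(df.columns),
                'dtypes': {col: str(dtype) for col, dtype in df.dtypes.items()},
                'shape': df.shape,
                'fingerprint': fingerprint,
                'statistics': df.describe().to_dict()
            }
            
            schema_path = Path(tmp_dir) / f"{name}_schema.json"
            with open(schema_path, 'w') as f:
                json.dump(schema, f, indent=2, default=str)
            mlflow.log_artifact(str(schema_path), artifact_path="datasets")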
    
    def log_model(self, model, model_name: str, 
                  feature_names: List[str] = None,
                  signature: bool = True):
        """Log model with signature and flavor-specific logging."""
        
        # Auto-detect model type
        model_type = type(model).__name__
        
        if 'XGB' in model_type:
            mlflow.xgboost.log_model(model, model_name)
        elif 'LGBM' in model_type:
            mlflow.lightgbm.log_model(model, model_name)
        elif hasattr(model, 'fit'):
            mlflow.sklearn.log_model(model, model_name)
        else:
            mlflow.pyfunc.log_model(
                model_name,
                python_model=PythonModelWrapper(model)
            )
        
        # Log model size
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.pkl') as f:
            joblib.dump(model, f.name)
            size_mb = Path(f.name).stat().st_size / (1024 * 1024)
            mlflow.log_metric("model_size_mb", size_mb)
    
    def log_explainer(self, model, X_background: pd.DataFrame, 
                      X_explain: pd.DataFrame, model_name: str):
        """Log SHAP explanations."""
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_explain)
        
        # Save SHAP values
        shap_df = pd.DataFrame(shap_values, columns=X_explain.columns)
        shap_df.to_parquet(f"shap_values_{model_name}.parquet")
        mlflow.log_artifact(f"shap_values_{model_name}.parquet")
        
        # Log feature importance
        feature_importance = np.abs(shap_values).mean(axis=0)
        importance_dict = dict(zip(X_explain.columns, feature_importance))
        
        mlflow.log_dict(importance_dict, f"feature_importance_{model_name}.json")
    
    def compare_runs(self, metric: str = "auc_roc", 
                     n_runs: int = 10) -> pd.DataFrame:
        """Compare recent runs by metric."""
        experiment = self.client.get_experiment_by_name(
            self.experiment_name
        )
        
        runs = self.client.search_runs(
            experiment_ids=[experiment.experiment_id],
            order_by=[f"metrics.{metric} DESC"],
            max_results=n_runs
        )
        
        results = []
        for run in runs:
            result = {
                'run_id': run.info.run_id,
                'run_name': run.info.run_name,
                'start_time': run.info.start_time,
                'status': run.info.status,
                'metric': run.data.metrics.get(metric),
            }
            result.update(run.data.params)
            results.append(result)
        
        return pd.DataFrame(results)
    
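    def promote_model(self, run_id: str, 
                      model_name: str,
                      stage: str = "Production",
                      description: str = "",
                      artifact_path: Optional[str] = None):
        """Register a run's model and move it to a specific stage.
        
        artifact_path defaults to model_name, which is the artifact path
        that log_model() above uses when logging the model.
        """
        model_uri = f"runs:/{run_id}/{artifact_path or model_name}"
        
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name
        )
        
        # Stage transitions are deprecated in recent MLflow releases in favor
        # of model version aliases; this call applies where stages still exist.
        self.client.transition_model_version_stage(
            name=model_name,
            version=result.version,
            stage=stage
        )
        
        self.client.update_model_version(
            name=model_name,
            version=result.version,
            description=description
        )
        
        return result.version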

ℹ️

MLflow's model registry provides stage transitions (Staging β†’ Production β†’ Archived) with full lineage tracking. Use this for governance and audit trails in regulated industries.
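
The governance idea in the note above can be sketched independently of any registry product: a small state machine that only permits certain transitions and records who made each one. The transition policy and the `VersionRecord` class below are assumptions for illustration, not MLflow's implementation.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Set

# Assumed transition policy for illustration; a real registry may allow more.
ALLOWED_TRANSITIONS: Dict[str, Set[str]] = {
    "None": {"Staging"},
    "Staging": {"Production", "Archived"},
    "Production": {"Archived"},
    "Archived": {"Staging"},
}


@dataclass
class VersionRecord:
    """Hypothetical registry entry with an append-only audit log."""
    name: str
    version: int
    stage: str = "None"
    audit_log: List[dict] = field(default_factory=list)

    def transition(self, new_stage: str, actor: str) -> None:
        """Move to new_stage if the policy allows it, recording who and when."""
        if new_stage not in ALLOWED_TRANSITIONS.get(self.stage, set()):
            raise ValueError(f"{self.stage} -> {new_stage} is not allowed")
        self.audit_log.append({
            "from": self.stage,
            "to": new_stage,
            "actor": actor,
            "at": datetime.now(timezone.utc).isoformat(),
        })
        self.stage = new_stage
```

Rejecting a direct Production to Staging move forces a demotion to pass through Archived first, which keeps the audit log unambiguous about when a version stopped serving traffic.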

Weights & Biases Integration

W&B Project Setup

import wandb
from wandb.integration.xgboost import WandbCallback
import xgboost as xgb
import pandas as pd
import numpy as np
from typing import Dict, Any

class WNBExperimentTracker:
    def __init__(self, project: str, entity: str = None):
        self.project = project
        self.entity = entity
    
    def init_run(self, config: Dict[str, Any], 
                 job_type: str = "training"):
        """Initialize W&B run."""
        wandb.init(
            project=self.project,
            entity=self.entity,
            config=config,
            job_type=job_type,
            tags=config.get('tags', []),
            notes=config.get('notes', '')
        )
        
        return wandb.run
    
    def log_training(self, model, X_train, y_train, 
                     X_val, y_val, config: dict):
        """Log complete training with W&B."""
        
        # Create datasets
        wandb_train = wandb.Table(dataframe=X_train.head(1000))
        wandb_val = wandb.Table(dataframe=X_val.head(1000))
        
        wandb.log({
            "train_data": wandb_train,
            "val_data": wandb_val,
        })
        
        # Train with automatic logging
        dtrain = xgb.DMatrix(X_train, label=y_train)
        dval = xgb.DMatrix(X_val, label=y_val)
        
        callbacks = [WandbCallback()]
        
        model = xgb.train(
            config,
            dtrain,
            num_boost_round=1000,
            evals=[(dtrain, "train"), (dval, "val")],
            callbacks=callbacks,
            early_stopping_rounds=50
        )
        
        # Log feature importance
        importance = model.get_score(importance_type='gain')
        importance_table = wandb.Table(
            columns=["feature", "importance"],
            data=[[k, v] for k, v in sorted(
                importance.items(), key=lambda x: x[1], reverse=True
            )]
        )
        
        wandb.log({"feature_importance": importance_table})
        
        # Save the booster locally first, then upload the file to the run
        model.save_model("model.json")
        wandb.save("model.json")
        
        return model
    
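    def log_metrics_series(self, metrics_dict: Dict[str, list]):
        """Log time series metrics, one wandb.log call per step."""
        # W&B expects steps to be non-decreasing, so iterate step by step
        # instead of finishing one metric before starting the next.
        n_steps = max((len(v) for v in metrics_dict.values()), default=0)
        for step in range(n_steps):
            wandb.log(
                {name: values[step]
                 for name, values in metrics_dict.items()
                 if step < len(values)},
                step=step
            )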
    
    def log_confusion_matrix(self, y_true, y_pred, labels=None):
        """Log interactive confusion matrix."""
        wandb.log({
            "confusion_matrix": wandb.plot.confusion_matrix(
                y_true=y_true,
                preds=y_pred,
                class_names=labels  # wandb names this argument class_names
            )
        })
    
    def log_roc_curve(self, y_true, y_probas):
        """Log ROC curve."""
        wandb.log({
            "roc_curve": wandb.plot.roc_curve(
                y_true=y_true,
                y_probas=y_probas
            )
        })
    
    def log_pr_curve(self, y_true, y_probas):
        """Log Precision-Recall curve."""
        wandb.log({
            "pr_curve": wandb.plot.pr_curve(
                y_true=y_true,
                y_probas=y_probas
            )
        })
    
    def create_model_card(self, model_name: str, metrics: dict,
                          config: dict, description: str):
        """Create and log model card."""
        metrics_rows = "\n".join(
            f"| {k} | {v:.4f} |" for k, v in metrics.items()
        )
        config_rows = "\n".join(
            f"    {k}: {v}" for k, v in config.items()
        )
        # The Training Data, Limitations, and Ethical Considerations sections
        # are placeholder text; replace them with facts about your own model.
        model_card = (
            f"# {model_name}\n\n"
            "## Description\n\n"
            f"{description}\n\n"
            "## Training Configuration\n\n"
            f"{config_rows}\n\n"
            "## Evaluation Metrics\n"
            "| Metric | Value |\n"
            "|--------|-------|\n"
            f"{metrics_rows}\n\n"
            "## Training Data\n"
            "- Dataset fingerprint: abc123def456\n"
            "- Number of samples: 100000\n"
            "- Features: 42\n\n"
            "## Limitations\n"
            "- Trained on historical data from 2023-01-01 to 2023-12-31\n"
            "- May not generalize to out-of-distribution data\n\n"
            "## Ethical Considerations\n"
            "- Model reviewed for fairness on protected attributes\n"
            "- Bias testing performed on 2023-11-15\n"
        )
        
        with open("model_card.md", "w") as f:
            f.write(model_card)
        
        wandb.save("model_card.md")
        
        wandb.log({
            "model_card": wandb.Table(
                columns=["section", "content"],
                data=[
                    ["Description", description],
                    ["Metrics", str(metrics)],
                    ["Config", str(config)]
                ]
            )
        })

DVC for Data and Model Versioning

DVC Pipeline Configuration

# dvc.yaml
stages:
  prepare:
    cmd: python src/data/prepare.py
    deps:
      - src/data/prepare.py
      - data/raw/dataset.csv
    params:
      - data.yaml:
          - test_size
          - random_state
    outs:
      - data/prepared/train.csv
      - data/prepared/test.csv
    metrics:
      - metrics/data_stats.json:
          cache: false

  featurize:
    cmd: python src/features/build.py
    deps:
      - src/features/build.py
      - data/prepared/train.csv
      - data/prepared/test.csv
    outs:
      - data/features/train_features.npy
      - data/features/test_features.npy

  train:
    cmd: python src/models/train.py
    deps:
      - src/models/train.py
      - data/features/train_features.npy
    params:
      - model.yaml:
          - n_estimators
          - max_depth
          - learning_rate
    outs:
      - models/model.pkl
    metrics:
      - metrics/train.json:
          cache: false

  evaluate:
    cmd: python src/models/evaluate.py
    deps:
      - src/models/evaluate.py
      - models/model.pkl
      - data/features/test_features.npy
    metrics:
      - metrics/evaluate.json:
          cache: false
    plots:
      - plots/confusion_matrix.json:
          cache: false
      - plots/roc_curve.json:
          cache: false

  deploy:
    cmd: python src/deployment/deploy.py
    deps:
      - src/deployment/deploy.py
      - models/model.pkl
    params:
      - deployment.yaml:
          - target_environment
          - canary_percentage
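
The pipeline above forms a dependency graph: each stage consumes files that earlier stages produce. As a simplified sketch of how a tool can decide what to re-run after a file changes, the following uses the standard-library `graphlib` module. It is not DVC's implementation (DVC compares content hashes recorded in `dvc.lock`), and `STAGES`, `stage_graph`, and `stages_to_rerun` are hypothetical names, with source-file dependencies omitted for brevity.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

# Stage -> deps and outs, transcribed from the dvc.yaml above (scripts omitted).
STAGES: Dict[str, Dict[str, List[str]]] = {
    "prepare": {"deps": ["data/raw/dataset.csv"],
                "outs": ["data/prepared/train.csv", "data/prepared/test.csv"]},
    "featurize": {"deps": ["data/prepared/train.csv"],
                  "outs": ["data/features/train_features.npy",
                           "data/features/test_features.npy"]},
    "train": {"deps": ["data/features/train_features.npy"],
              "outs": ["models/model.pkl"]},
    "evaluate": {"deps": ["models/model.pkl", "data/features/test_features.npy"],
                 "outs": []},
    "deploy": {"deps": ["models/model.pkl"], "outs": []},
}


def stage_graph(stages: Dict[str, Dict[str, List[str]]]) -> Dict[str, Set[str]]:
    """Map each stage to the set of stages that produce its dependencies."""
    producer = {out: name for name, s in stages.items() for out in s["outs"]}
    return {
        name: {producer[d] for d in s["deps"] if d in producer}
        for name, s in stages.items()
    }


def stages_to_rerun(stages: Dict[str, Dict[str, List[str]]],
                    changed_path: str) -> List[str]:
    """Return, in execution order, the stages affected by a changed file."""
    graph = stage_graph(stages)
    order = list(TopologicalSorter(graph).static_order())
    dirty: Set[str] = set()
    for name in order:
        # A stage is stale if it reads the changed file or any upstream stage is stale
        if changed_path in stages[name]["deps"] or graph[name] & dirty:
            dirty.add(name)
    return [name for name in order if name in dirty]
```

For example, `stages_to_rerun(STAGES, "data/raw/dataset.csv")` returns every stage, while a change that touches only `models/model.pkl` affects just `evaluate` and `deploy`.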

DVC Data Versioning Commands

# Initialize DVC in your Git repo
dvc init

# Add data to DVC tracking (dvc add also writes a .gitignore entry for the data file)
dvc add data/training/dataset.csv
git add data/training/dataset.csv.dvc data/training/.gitignore

# Define a pipeline stage (dvc run was removed in DVC 3.0; use dvc stage add)
dvc stage add -n prepare \
    -d src/data/prepare.py \
    -d data/raw/dataset.csv \
    -o data/prepared/train.csv \
    -o data/prepared/test.csv \
    -p data.yaml:test_size,random_state \
    python src/data/prepare.py

# Run the full pipeline
dvc repro

# Compare metrics and plots between the workspace and the last commit
dvc metrics diff
dvc plots diff -o plots/

# Push data to remote storage
dvc push

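# Switch to a different data version (assumes a Git tag named v2.1)
git checkout v2.1 -- data/training/dataset.csv.dvc
dvc checkout data/training/dataset.csv.dvc

# Fetch that version from remote storage if it is not in the local cache
dvc pull data/training/dataset.csv.dvc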
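
The commands above work because the small `.dvc` file that Git versions records a content hash of the data file (MD5 by default, to the best of my knowledge), so switching Git revisions tells DVC which data to restore. A minimal sketch of that idea, with hypothetical helper names `file_md5` and `has_changed` that are not part of DVC:

```python
import hashlib
from pathlib import Path


def file_md5(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large datasets need not fit in memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def has_changed(path: Path, recorded_hash: str) -> bool:
    """True when the file content no longer matches the recorded hash."""
    return file_md5(path) != recorded_hash
```

Because the hash depends only on content, renaming a file or touching its timestamp does not register as a new data version, while changing a single byte does.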

Model Versioning Strategies

Semantic Versioning for ML Models

from enum import Enum
from dataclasses import dataclass
from typing import Optional
import json
from datetime import datetime

class ModelVersionType(Enum):
    MAJOR = "major"  # Breaking API changes, incompatible features
    MINOR = "minor"  # New features, backward compatible
    PATCH = "patch"  # Bug fixes, performance improvements

@dataclass
class ModelVersion:
    major: int
    minor: int
    patch: int
    pre_release: Optional[str] = None
    
    def __str__(self):
        version = f"{self.major}.{self.minor}.{self.patch}"
        if self.pre_release:
            version += f"-{self.pre_release}"
        return version
    
    @classmethod
    def from_string(cls, version_str: str) -> 'ModelVersion':
        # Split only on the first hyphen so pre-release tags like "rc-1" survive
        parts = version_str.split('-', 1)
        version_parts = parts[0].split('.')
        
        return cls(
            major=int(version_parts[0]),
            minor=int(version_parts[1]),
            patch=int(version_parts[2]),
            pre_release=parts[1] if len(parts) > 1 else None
        )
    
    def increment(self, version_type: ModelVersionType) -> 'ModelVersion':
        if version_type == ModelVersionType.MAJOR:
            return ModelVersion(
                self.major + 1, 0, 0, self.pre_release
            )
        elif version_type == ModelVersionType.MINOR:
            return ModelVersion(
                self.major, self.minor + 1, 0, self.pre_release
            )
        else:
            return ModelVersion(
                self.major, self.minor, self.patch + 1, self.pre_release
            )

class ModelVersionManager:
    def __init__(self, registry_path: str):
        self.registry_path = registry_path
        self.registry = self._load_registry()
    
    def _load_registry(self):
        try:
            with open(f"{self.registry_path}/registry.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return {"models": {}}
    
    def _save_registry(self):
        with open(f"{self.registry_path}/registry.json", 'w') as f:
            json.dump(self.registry, f, indent=2, default=str)
    
    def register_model(self, model_name: str, model_path: str,
                       metrics: dict, config: dict,
                       version_type: ModelVersionType = ModelVersionType.MINOR):
        """Register a new model version."""
        
        if model_name not in self.registry["models"]:
            self.registry["models"][model_name] = {
                "versions": [],
                "current_version": None
            }
        
        model_info = self.registry["models"][model_name]
        
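        if model_info["versions"]:
            # Increment from the most recently registered version rather than
            # current_version, so a rollback cannot cause a version number
            # to be issued twice.
            latest = ModelVersion.from_string(model_info["versions"][-1]["version"])
            new_version = latest.increment(version_type)
        else:
            new_version = ModelVersion(1, 0, 0)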
        
        version_entry = {
            "version": str(new_version),
            "model_path": model_path,
            "metrics": metrics,
            "config": config,
            "created_at": datetime.now().isoformat(),
            "stage": "Development",
            "tags": []
        }
        
        model_info["versions"].append(version_entry)
        model_info["current_version"] = str(new_version)
        
        self._save_registry()
        
        return new_version
    
    def promote_version(self, model_name: str, version: str,
                        stage: str, notes: str = ""):
        """Promote a model version to a new stage."""
        model_info = self.registry["models"][model_name]
        
        for v in model_info["versions"]:
            if v["version"] == version:
                v["stage"] = stage
                v["promoted_at"] = datetime.now().isoformat()
                v["promotion_notes"] = notes
                break
        
        self._save_registry()
    
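    def rollback(self, model_name: str, to_version: str):
        """Rollback production to a previous version."""
        model_info = self.registry["models"][model_name]
        
        target = next(
            (v for v in model_info["versions"] if v["version"] == to_version),
            None
        )
        if target is None:
            raise ValueError(f"Unknown version {to_version} for {model_name}")
        
        # Demote whatever is currently in Production so only one version serves
        for v in model_info["versions"]:
            if v["stage"] == "Production":
                v["stage"] = "Archived"
        
        target["stage"] = "Production"
        target["rollback_at"] = datetime.now().isoformat()
        model_info["current_version"] = to_version
        
        self._save_registry()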

⚠️

Always version your models with semantic versioning. Breaking changes to input features or output format require a MAJOR version bump to prevent silent failures in production.
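
One way to apply this rule mechanically is to compare the old and new input schemas and output format before registering a model. The sketch below is an assumed policy, not a standard: removed or retyped features and a changed output format count as breaking, while added features are assumed to be optional and therefore backward compatible. The function name `required_bump` is hypothetical.

```python
from typing import Dict


def required_bump(old_schema: Dict[str, str], new_schema: Dict[str, str],
                  old_output: str, new_output: str) -> str:
    """Suggest a version bump from input-feature and output-format changes.

    Schemas map feature name -> dtype string. Assumed policy: removed or
    retyped features, or a changed output format, are breaking (major);
    added features are treated as optional and compatible (minor).
    """
    removed = old_schema.keys() - new_schema.keys()
    retyped = {f for f in old_schema.keys() & new_schema.keys()
               if old_schema[f] != new_schema[f]}
    if removed or retyped or old_output != new_output:
        return "major"
    if new_schema.keys() - old_schema.keys():
        return "minor"
    return "patch"
```

The returned string matches the values of the `ModelVersionType` enum above, so it could be passed through `ModelVersionType(...)` when registering a model.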

Reproducibility Checklist

Ensure complete reproducibility with this checklist:

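import json
import platform
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd

REPRODUCIBILITY_CHECKLIST = {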
    "code": {
        "git_commit": "SHA hash of code",
        "branch": "Git branch name",
        "dependencies": "requirements.txt / poetry.lock",
        "docker_image": "Docker image tag with SHA"
    },
    "data": {
        "dataset_fingerprint": "SHA256 of dataset",
        "data_version": "DVC/Git tag",
        "source_path": "Original data location",
        "row_count": "Number of samples",
        "feature_count": "Number of features"
    },
    "configuration": {
        "hyperparameters": "All model parameters",
        "training_config": "Full training configuration",
        "random_seeds": "All random seeds used",
        "environment": "Python version, GPU driver"
    },
    "metrics": {
        "evaluation_metrics": "All metrics with confidence intervals",
        "training_curves": "Loss/metric trajectories",
        "confusion_matrix": "Full confusion matrix",
        "feature_importance": "Feature importance rankings"
    },
    "artifacts": {
        "model_file": "Serialized model (pickle/joblib/onnx)",
        "preprocessor": "Fitted preprocessing pipeline",
        "feature_selector": "Feature selection mask"
    }
}

def create_reproducibility_package(run_info: dict, output_dir: str):
    """Create a complete reproducibility package."""
    
    package = {
        "metadata": {
            "run_id": run_info["run_id"],
            "experiment": run_info["experiment"],
            "created_at": datetime.now().isoformat(),
            "created_by": run_info.get("user", "unknown")
        },
        "code": {
            "git_commit": run_info.get("git_commit"),
            "git_branch": run_info.get("git_branch"),
            "python_version": sys.version,
        },
        "data": run_info.get("data_info", {}),
        "config": run_info.get("config", {}),
        "metrics": run_info.get("metrics", {}),
        "environment": {
            "platform": platform.platform(),
            "python": sys.version,
            "numpy": np.__version__,
            "pandas": pd.__version__,
            # Read the version from sys.modules so torch need not be imported here
            "torch": sys.modules['torch'].__version__ if 'torch' in sys.modules else None,
        }
    }
    
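    # Save package, creating the output directory if needed
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    package_path = out_dir / f"reproducibility_{run_info['run_id']}.json"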
    with open(package_path, 'w') as f:
        json.dump(package, f, indent=2, default=str)
    
    return package_path
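
A checklist is more useful when it is enforced. As a sketch, the function below reports which required entries are absent from a package shaped like the one built above. The `REQUIRED_FIELDS` mapping is a hypothetical subset chosen for illustration; adjust it to whatever your team treats as mandatory.

```python
from typing import Any, Dict, List

# Hypothetical required fields per section, loosely following the checklist above.
REQUIRED_FIELDS: Dict[str, List[str]] = {
    "code": ["git_commit", "git_branch"],
    "data": ["dataset_fingerprint", "row_count"],
    "config": ["hyperparameters", "random_seeds"],
}


def missing_fields(package: Dict[str, Any]) -> List[str]:
    """Return 'section.field' for every required entry that is absent or empty."""
    missing = []
    for section, fields in REQUIRED_FIELDS.items():
        values = package.get(section) or {}
        for name in fields:
            # A legitimate zero (e.g. row_count == 0) is not treated as missing
            if values.get(name) in (None, "", [], {}):
                missing.append(f"{section}.{name}")
    return missing
```

Running this check at the end of a training job, and failing the job when the list is non-empty, turns the checklist from documentation into a gate.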

Experiment Comparison Dashboard

class ExperimentDashboard:
    """Compare experiments across multiple dimensions."""
    
    def __init__(self, tracker):
        # tracker must expose get_run(run_id), e.g. an mlflow.tracking.MlflowClient
        self.tracker = tracker
    
    def generate_comparison_report(self, run_ids: list) -> str:
        """Generate HTML comparison report."""
        
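        runs_data = []
        for run_id in run_ids:
            run = self.tracker.get_run(run_id)
            # MLflow reports timestamps in milliseconds; end_time is None
            # while a run is still active
            if run.info.end_time is not None:
                duration = (run.info.end_time - run.info.start_time) / 1000.0
            else:
                duration = float('nan')
            runs_data.append({
                'id': run_id,
                'name': run.info.run_name,
                'metrics': run.data.metrics,
                'params': run.data.params,
                'start_time': run.info.start_time,
                'duration': duration
            })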
        
        html = """
        <html>
        <head>
            <title>Experiment Comparison</title>
            <style>
                table { border-collapse: collapse; width: 100%; }
                th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
                th { background-color: #f2f2f2; }
                .best { background-color: #d4edda; }
                .worst { background-color: #f8d7da; }
            </style>
        </head>
        <body>
            <h1>Experiment Comparison</h1>
            <table>
                <tr>
                    <th>Run</th>
                    <th>Duration</th>
        """
        
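        # Add metric columns from one shared, ordered list so columns line up
        metric_names = sorted({m for r in runs_data for m in r['metrics']})
        for metric in metric_names:
            html += f"<th>{metric}</th>"
        
        html += "</tr>"
        
        # Best and worst are computed per metric across runs
        # (assumes higher values are better)
        best, worst = {}, {}
        for metric in metric_names:
            values = [r['metrics'][metric] for r in runs_data
                      if metric in r['metrics']]
            best[metric] = max(values)
            worst[metric] = min(values)
        
        # Add rows
        for run in runs_data:
            html += f"""
            <tr>
                <td>{run['name']}</td>
                <td>{run['duration']:.1f}s</td>
            """
            
            for metric in metric_names:
                value = run['metrics'].get(metric)
                if value is None:
                    html += "<td>n/a</td>"
                    continue
                css_class = ""
                if best[metric] != worst[metric]:
                    if value == best[metric]:
                        css_class = ' class="best"'
                    elif value == worst[metric]:
                        css_class = ' class="worst"'
                html += f"<td{css_class}>{value:.4f}</td>"
            
            html += "</tr>"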
        
        html += """
            </table>
        </body>
        </html>
        """
        
        return html

ℹ️

At scale, experiment tracking systems can generate millions of runs. Use tags, custom metadata, and automated filtering to manage this volume effectively. Consider implementing experiment analytics pipelines for trend analysis.
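
As a sketch of the tag-based filtering and trend analysis mentioned above, the helpers below operate on plain dictionaries rather than any tracker's API. The run layout (`tags` and `metrics` keys) and the function names are assumptions for illustration.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, List


def filter_runs(runs: Iterable[dict], **required_tags: str) -> List[dict]:
    """Keep runs whose tags contain every required key/value pair."""
    return [
        r for r in runs
        if all(r.get("tags", {}).get(k) == v for k, v in required_tags.items())
    ]


def mean_metric_by_tag(runs: Iterable[dict], tag: str,
                       metric: str) -> Dict[str, float]:
    """Average a metric per tag value, skipping runs that lack either."""
    buckets = defaultdict(list)
    for r in runs:
        tag_value = r.get("tags", {}).get(tag)
        value = r.get("metrics", {}).get(metric)
        if tag_value is not None and value is not None:
            buckets[tag_value].append(value)
    return {k: mean(v) for k, v in buckets.items()}
```

This only pays off if tags are applied consistently at run creation, so it is worth agreeing on a small fixed vocabulary (team, model type, dataset version) early.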

Summary

Experiment tracking is foundational to ML reproducibility:

  1. MLflow: Open-source, multi-framework support, model registry
  2. W&B: Interactive visualizations, real-time collaboration, sweeps
  3. Neptune: Metadata store, experiment comparison, team management
  4. DVC: Git-based data versioning, pipeline management

Choose based on your team size, infrastructure, and collaboration needs. Always ensure reproducibility with complete metadata logging.
