Experiment Tracking with MLflow
Reproducibility in ML requires tracking every experiment, parameter, and metric. MLflow provides a complete platform for experiment management and model lifecycle.
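MLflow Components
MLflow is organized into four components: Tracking (logging runs with their parameters, metrics, and artifacts), Models (a standard packaging format with per-library "flavors"), the Model Registry (versioning and lifecycle management for models), and Projects (packaging code for reproducible runs).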
Experiment Tracking Basics
import mlflow
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
# Set tracking URI (local or remote)
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("iris-classification")

# Load data. Seed the split as well as the model: without random_state every
# run trains and evaluates on a different split, so runs are not comparable.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

# Train and log
with mlflow.start_run(run_name="rf-baseline"):
    import matplotlib.pyplot as plt
    from mlflow.models import infer_signature

    # Parameters
    params = {
        "n_estimators": 100,
        "max_depth": 5,
        "min_samples_split": 3,
        "random_state": 42,
    }
    mlflow.log_params(params)

    # Train model
    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="weighted")

    # Log metrics
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1)

    # Log model with an inferred input/output signature and a small input example
    signature = infer_signature(X_train, model.predict(X_train))
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        signature=signature,
        input_example=X_train[:5],
        registered_model_name="iris-rf",
    )

    # Log additional artifacts. Using the feature names as the bar positions
    # keeps each label paired with its bar; set_yticklabels() without
    # set_yticks() would put the labels on auto-located ticks instead.
    fig, ax = plt.subplots()
    ax.barh(feature_names, model.feature_importances_)
    ax.set_xlabel("Importance")
    # log_figure uploads the figure directly; no stray PNG left in the cwd.
    mlflow.log_figure(fig, "feature_importance.png")
    plt.close(fig)  # pyplot keeps figures alive until explicitly closed

print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
Advanced Tracking Patterns
import mlflow
from functools import wraps
import time
def track_experiment(func):
    """Decorate ``func`` so each call runs inside its own nested MLflow run.

    The elapsed time of the call, in seconds, is logged as the ``total_time``
    metric after ``func`` returns. If ``func`` raises, the exception propagates
    out of the run context and no timing metric is logged.

    Args:
        func: the callable to wrap; its name and docstring are preserved.

    Returns:
        The wrapped callable, returning whatever ``func`` returns.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        with mlflow.start_run(nested=True):
            # perf_counter is monotonic; time.time() follows the system clock
            # and can jump (even backwards) if the clock is adjusted mid-call.
            start = time.perf_counter()
            result = func(*args, **kwargs)
            mlflow.log_metric("total_time", time.perf_counter() - start)
            return result

    return wrapper
# Track nested runs for hyperparameter sweeps
def run_hyperparameter_search(X_train, X_test, y_train, y_test, param_grid=None, random_state=42):
    """Grid-search a RandomForestClassifier, logging each trial as a nested run.

    A parent run named "hyperparameter-sweep" is opened. Each parameter
    combination gets a nested child run that logs its params and "accuracy".
    The parent then logs the winning params (prefixed "best_") and
    "best_accuracy".

    NOTE: candidates are scored on (X_test, y_test). Selecting hyperparameters
    on the test set leaks it into model selection. Pass a held-out validation
    split here if the test set must stay untouched.

    Args:
        X_train, y_train: training features/labels.
        X_test, y_test: evaluation features/labels used to rank candidates.
        param_grid: mapping of RandomForestClassifier argument -> candidate
            values. Defaults to the n_estimators/max_depth/min_samples_split
            grid below.
        random_state: seed passed to every model unless param_grid overrides it.

    Returns:
        (best_params, best_score). Ties keep the first combination tried.

    Raises:
        ValueError: if any entry of param_grid has no candidate values.
    """
    from itertools import product

    if param_grid is None:
        param_grid = {
            "n_estimators": [50, 100, 200],
            "max_depth": [3, 5, 10, None],
            "min_samples_split": [2, 5, 10],
        }
    # Materialize candidates so any iterable works and emptiness can be checked
    # before any run is started.
    candidates = {name: list(values) for name, values in param_grid.items()}
    empty = [name for name, values in candidates.items() if not values]
    if empty:
        raise ValueError(f"param_grid has no candidate values for: {empty}")

    names = list(candidates)
    best_score = None  # None (not 0) so a real combination is always selected
    best_params = {}

    with mlflow.start_run(run_name="hyperparameter-sweep"):
        for combo in product(*(candidates[name] for name in names)):
            param_dict = dict(zip(names, combo))
            with mlflow.start_run(nested=True):
                mlflow.log_params(param_dict)
                # Grid values override the default seed if they include random_state.
                model = RandomForestClassifier(**{"random_state": random_state, **param_dict})
                model.fit(X_train, y_train)
                score = accuracy_score(y_test, model.predict(X_test))
                mlflow.log_metric("accuracy", score)
            if best_score is None or score > best_score:
                best_score = score
                best_params = param_dict

        mlflow.log_params({f"best_{k}": v for k, v in best_params.items()})
        mlflow.log_metric("best_accuracy", best_score)

    return best_params, best_score
# Log datasets for reproducibility. log_input attaches the dataset to the
# active run (creating one if none is active), so open a run explicitly
# instead of leaving an implicit run open.
import mlflow.data
import pandas as pd

feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

with mlflow.start_run(run_name="training-data"):
    mlflow.log_input(
        mlflow.data.from_pandas(
            pd.DataFrame(X_train, columns=feature_names),
            source="sklearn-iris",
            name="training_data",
        ),
        context="training",
    )
Model Registry
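The model registry provides versioning, stage transitions, and governance. (Note: stages are deprecated since MLflow 2.9 in favor of model version aliases and tags; the stage-based examples below still work but emit deprecation warnings.)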
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()

# Register a model (replace <run_id> with the ID of the run that logged it).
model_uri = "runs:/<run_id>/model"
# register_model returns the ModelVersion it created. Use its version number:
# hard-coding version=1 only works the first time this name is registered.
registered = mlflow.register_model(model_uri, "production-classifier")
model_version = registered.version

# Transition model stage
client.transition_model_version_stage(
    name="production-classifier",
    version=model_version,
    stage="Staging",
)

# Add description and tags
client.update_model_version(
    name="production-classifier",
    version=model_version,
    description="Random Forest classifier trained on iris dataset v1",
)
client.set_model_version_tag(
    name="production-classifier",
    version=model_version,
    key="validation_status",
    value="passed",
)

# Promote to production. archive_existing_versions=True moves any version
# already in Production to Archived, so only this version remains there.
client.transition_model_version_stage(
    name="production-classifier",
    version=model_version,
    stage="Production",
    archive_existing_versions=True,
)
# NOTE: stages are deprecated since MLflow 2.9; newer code can use
# client.set_registered_model_alias(...) and load "models:/<name>@<alias>".

# Load model from registry
import mlflow.sklearn

model = mlflow.sklearn.load_model("models:/production-classifier/Production")

# Get model version details
versions = client.search_model_versions("name='production-classifier'")
for version in versions:
    print(f"Version {version.version}: {version.current_stage}")
Custom Model Flavors
import mlflow
import mlflow.pyfunc
import pandas as pd
import numpy as np
class CustomModel(mlflow.pyfunc.PythonModel):
    """pyfunc model that chains preprocessing, prediction and optional postprocessing.

    Args:
        model: fitted estimator exposing ``predict``. It must have been trained
            on the output of ``preprocessor.transform``, because that is what
            ``predict`` feeds it.
        preprocessor: fitted transformer exposing ``transform``.
        postprocessor: optional callable applied to the raw predictions.
    """

    def __init__(self, model, preprocessor, postprocessor=None):
        self.model = model
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def predict(self, context, model_input, params=None):
        """Return predictions for ``model_input`` as a one-column DataFrame.

        ``context`` and ``params`` belong to the PythonModel interface and are
        not used here. ``params=None`` keeps the method compatible with MLflow
        versions that pass inference parameters.
        """
        processed = self.preprocessor.transform(model_input)
        predictions = self.model.predict(processed)
        # Test for None explicitly: a callable object can still be falsy.
        if self.postprocessor is not None:
            predictions = self.postprocessor(predictions)
        return pd.DataFrame({"prediction": predictions})
# Log custom model: bundle a scaler with a classifier trained on scaled features.
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
# Fit on a named DataFrame so the named DataFrame passed at inference time
# matches what the scaler saw during fit.
X_train_df = pd.DataFrame(X_train, columns=feature_names)

preprocessor = StandardScaler()
preprocessor.fit(X_train_df)
# The wrapped classifier must be trained on the same scaled features that
# CustomModel.predict feeds it. Reusing a model fitted on raw features would
# silently produce wrong predictions.
scaled_model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
scaled_model.fit(preprocessor.transform(X_train_df), y_train)
custom_model = CustomModel(scaled_model, preprocessor)

with mlflow.start_run(run_name="custom-model"):
    model_info = mlflow.pyfunc.log_model(
        artifact_path="custom_model",
        python_model=custom_model,
        # Pin the library versions actually used to fit and pickle the objects.
        # Hard-coded versions can drift from the training environment and
        # break unpickling at load time.
        pip_requirements=[
            f"scikit-learn=={sklearn.__version__}",
            f"pandas=={pd.__version__}",
            f"numpy=={np.__version__}",
        ],
        registered_model_name="custom-ensemble",
    )

# Load and use. A freshly registered version has no stage yet, so
# "models:/custom-ensemble/Production" would not resolve to it; load the
# artifact that was just logged instead.
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
predictions = loaded_model.predict(pd.DataFrame(X_test, columns=feature_names))
Deployment with MLflow
import mlflow.pyfunc
from fastapi import FastAPI
import uvicorn
import pandas as pd
# The served model is loaded once when the module is imported and shared by all
# request handlers below.
_MODEL_URI = "models:/production-classifier/Production"
model = mlflow.pyfunc.load_model(_MODEL_URI)

app = FastAPI()
@app.post("/predict")
def predict(features: dict):
    """Score one record supplied as a ``{feature_name: value}`` JSON object.

    Declared as a plain ``def`` so FastAPI runs it in its thread pool. The
    CPU-bound ``model.predict`` then does not block the event loop, as it
    would inside an ``async def``.
    """
    import numpy as np

    df = pd.DataFrame([features])
    prediction = model.predict(df)
    # pyfunc models may return an ndarray, Series or DataFrame, and DataFrame
    # has no .tolist(); np.asarray normalizes all of them.
    return {"prediction": np.asarray(prediction).tolist()}
@app.post("/predict/batch")
def predict_batch(data: list):
    """Score a batch of records given as a JSON list of feature objects.

    Declared as a plain ``def`` so FastAPI runs it in its thread pool. The
    CPU-bound ``model.predict`` then does not block the event loop, as it
    would inside an ``async def``.
    """
    import numpy as np

    df = pd.DataFrame(data)
    predictions = model.predict(df)
    # pyfunc models may return an ndarray, Series or DataFrame, and DataFrame
    # has no .tolist(); np.asarray normalizes all of them.
    return {"predictions": np.asarray(predictions).tolist()}
@app.get("/model/info")
async def model_info():
    """Describe the served model: registry name, source run and input schema."""
    metadata = model.metadata
    signature = metadata.signature
    return {
        "model_name": "production-classifier",
        # NOTE: this is the MLflow run ID that produced the model, not the
        # registry version number; the key name is kept for compatibility.
        "version": metadata.run_id,
        # A Schema object is not JSON-serializable, so convert it with
        # to_dict(). Models logged without a signature have signature=None.
        "features": signature.inputs.to_dict() if signature is not None else None,
    }
# Start server: uvicorn app:app --host 0.0.0.0 --port 8000
Best Practices
- Use nested runs for hyperparameter searches to organize experiments
- Log everything — parameters, metrics, artifacts, and datasets
- Tag runs with purpose, author, and environment for searchability
- Use model registry stages (or, since MLflow 2.9, model version aliases) for promotion workflows
- Set up model monitoring alongside tracking for production models