Data Versioning
Data versioning is the practice of tracking changes to datasets over time, enabling reproducibility, collaboration, and rollback capabilities in machine learning workflows.
Why Data Versioning Matters
- Reproducibility: Recreate exact dataset states
- Collaboration: Multiple team members work with consistent data
- Audit Trail: Track data changes and transformations
- Debugging: Identify when data issues were introduced
- Compliance: Meet regulatory requirements for data governance
Data Versioning Architecture
DVC (Data Version Control)
Basic DVC Usage
# Initialize DVC in your project (run inside an existing Git repository).
# `dvc init` creates and stages the .dvc/ config files; commit them.
dvc init
git commit -m "Initialize DVC"
# Track a dataset: DVC moves the file into its cache, writes a small
# pointer file (data/training_data.csv.dvc) and lists the data file in
# data/.gitignore so Git never stores the data itself.
dvc add data/training_data.csv
# Commit the pointer file AND the updated .gitignore (not the data).
git add data/training_data.csv.dvc data/.gitignore
git commit -m "Add training dataset v1"
# Push the cached data to remote storage. A default remote must be
# configured first, e.g.:
#   dvc remote add -d storage s3://my-bucket/dvc-store
dvc push
Python DVC Integration
import dvc.api
import pandas as pd
from dvc.repo import Repo
class DVCDataManager:
    """Convenience wrapper around a single DVC repository.

    Reads go through ``dvc.api``; tracking and pushing go through the
    ``dvc.repo.Repo`` opened at construction time.
    """

    def __init__(self, repo_path):
        # Open the repository once; every method operates on this repo.
        self.repo = Repo(repo_path)

    def get_data(self, path, rev=None):
        """Load a DVC-tracked CSV file as a DataFrame.

        Args:
            path: Path of the tracked file, relative to the repository root.
            rev: Git revision (commit, branch or tag) to read. ``None`` uses
                DVC's default (the workspace version for a local repo).

        Returns:
            pandas.DataFrame parsed from the file contents.
        """
        # Pass the repo explicitly so reads target the repository this
        # manager was built for, not whichever repo contains the CWD.
        with dvc.api.open(path, repo=self.repo.root_dir, rev=rev) as f:
            return pd.read_csv(f)

    def track_file(self, file_path):
        """Track *file_path* with DVC and push its data to the default remote.

        This method does not touch Git: the generated ``.dvc`` pointer file
        still has to be committed separately.
        """
        self.repo.add(file_path)
        self.repo.push()

    def get_data_version(self, file_path):
        """Return the Git revision currently checked out in the repository.

        ``file_path`` is accepted for API symmetry but unused: the revision
        is repository-wide, not per file.
        """
        return self.repo.scm.get_rev()

    def list_versions(self, file_path):
        """List the Git commits that changed the DVC pointer for *file_path*.

        The data lives in the DVC cache; its history is the Git history of
        the ``<file_path>.dvc`` pointer file. Files produced as pipeline
        outputs (tracked via ``dvc.lock`` rather than a ``.dvc`` file) are
        not covered.

        Args:
            file_path: Tracked file path, relative to the repository root.

        Returns:
            List of commit hashes, newest first; empty if the pointer file
            was never committed.

        Raises:
            subprocess.CalledProcessError: if ``git log`` fails.
        """
        import subprocess

        completed = subprocess.run(
            ["git", "log", "--format=%H", "--", f"{file_path}.dvc"],
            cwd=self.repo.root_dir,
            capture_output=True,
            text=True,
            check=True,
        )
        return completed.stdout.split()

    def compare_versions(self, path, rev1, rev2):
        """Compare the shape and columns of *path* at two revisions.

        Returns:
            dict with:
              - ``shape_diff``: ``(shape_at_rev1, shape_at_rev2)``
              - ``column_diff``: columns present at *rev1* but not at *rev2*
              - ``columns_added``: columns present at *rev2* but not at *rev1*
              - ``row_diff``: ``rows(rev1) - rows(rev2)``
        """
        data1 = self.get_data(path, rev1)
        data2 = self.get_data(path, rev2)
        cols1, cols2 = set(data1.columns), set(data2.columns)
        return {
            "shape_diff": (data1.shape, data2.shape),
            "column_diff": cols1 - cols2,
            "columns_added": cols2 - cols1,
            "row_diff": len(data1) - len(data2),
        }
Delta Lake Implementation
Delta Lake Versioning
from delta.tables import DeltaTable
from pyspark.sql import SparkSession
class DeltaLakeVersionManager:
    """Create, update and time-travel Delta tables registered by name.

    Args:
        spark: Active ``SparkSession`` with Delta Lake configured.
        storage_path: Base storage location. It is stored on the instance
            but not used by the methods below, which address tables by
            catalog name.
    """

    def __init__(self, spark, storage_path):
        self.spark = spark
        self.storage_path = storage_path

    def create_table(self, data, table_name):
        """Create Delta table *table_name* from *data*.

        *data* is anything ``SparkSession.createDataFrame`` accepts. With
        Spark's default save mode this fails if the table already exists.
        """
        df = self.spark.createDataFrame(data)
        df.write.format("delta").saveAsTable(table_name)

    def upsert_data(self, data, table_name, key_columns):
        """Merge *data* into *table_name*, matching rows on *key_columns*.

        Matched rows are fully updated and unmatched rows are inserted; the
        merge is committed as a new table version.

        Raises:
            ValueError: if *key_columns* is empty (no merge condition).
        """
        if not key_columns:
            raise ValueError("key_columns must contain at least one column")
        delta_table = DeltaTable.forName(self.spark, table_name)
        source = self.spark.createDataFrame(data)
        condition = " AND ".join(
            f"target.{col} = source.{col}" for col in key_columns
        )
        (
            delta_table.alias("target")
            .merge(source.alias("source"), condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )

    def get_version(self, table_name, version=None):
        """Read *table_name* at *version*, or the latest version if ``None``."""
        reader = self.spark.read.format("delta")
        # Compare with None explicitly: 0 is a valid version (the table's
        # first one) but is falsy, so `if version:` would read the latest.
        if version is not None:
            reader = reader.option("versionAsOf", version)
        return reader.table(table_name)

    def get_table_history(self, table_name):
        """Return the table's commit history as a list of collected Rows."""
        delta_table = DeltaTable.forName(self.spark, table_name)
        return delta_table.history().collect()

    def rollback_to_version(self, table_name, version):
        """Restore *table_name* to the state it had at *version*.

        Delta records the restore as a new commit; later versions remain in
        the history.
        """
        delta_table = DeltaTable.forName(self.spark, table_name)
        delta_table.restoreToVersion(version)
Delta Lake Time Travel
class DeltaTimeTravel:
    """Read and compare historical snapshots of path-based Delta tables."""

    def __init__(self, spark):
        self.spark = spark

    def query_at_timestamp(self, table_path, timestamp):
        """Return the Delta table at *table_path* as of *timestamp*."""
        return (
            self.spark.read.format("delta")
            .option("timestampAsOf", timestamp)
            .load(table_path)
        )

    def query_at_version(self, table_path, version):
        """Return the Delta table at *table_path* as of *version*."""
        return (
            self.spark.read.format("delta")
            .option("versionAsOf", version)
            .load(table_path)
        )

    def compare_versions(self, table_path, version1, version2):
        """Count row-level differences between two versions.

        Rows are compared as multisets (each duplicate counts separately), so
        ``common_rows + removed_rows`` equals the row count of *version1* and
        ``common_rows + added_rows`` equals the row count of *version2*.

        Returns:
            dict with ``added_rows``, ``removed_rows`` and ``common_rows``.
        """
        df1 = self.query_at_version(table_path, version1)
        df2 = self.query_at_version(table_path, version2)
        removed = df1.exceptAll(df2)
        added = df2.exceptAll(df1)
        # intersectAll keeps duplicates, matching the multiset semantics of
        # exceptAll; plain intersect() would deduplicate and undercount.
        common = df1.intersectAll(df2)
        return {
            "added_rows": added.count(),
            "removed_rows": removed.count(),
            "common_rows": common.count(),
        }
Data Lineage
Lineage Tracking
from dataclasses import dataclass
from typing import List, Dict, Any
import networkx as nx
@dataclass
class DataLineageNode:
    """One node in a data lineage graph: a dataset, transformation or model."""

    id: str  # identifier used as the node's key
    name: str  # human-readable name
    type: str # 'dataset', 'transformation', 'model'
    metadata: Dict[str, Any]  # free-form, caller-supplied attributes
    created_at: str  # creation time as an ISO-8601 string
@dataclass
class DataLineageEdge:
    """A directed dependency from an upstream node to a downstream node."""

    source: str  # id of the upstream node
    target: str  # id of the downstream node
    transformation: str  # name of the step that links source to target
    parameters: Dict[str, Any]  # parameters of that step (may be empty)
class DataLineageTracker:
    """Track lineage between datasets and transformations as a directed graph.

    Each node is stored twice: as a ``DataLineageNode`` in ``self.nodes``
    (keyed by id) and as an attributed node in the ``networkx.DiGraph``
    ``self.graph``. Edges run from a source (upstream) node to a target
    (downstream) node.
    """

    def __init__(self):
        self.graph = nx.DiGraph()
        self.nodes = {}

    def _add_node(self, node_id, name, node_type, metadata):
        """Create a node of *node_type* and register it in both stores."""
        # Imported here (as visualize_lineage() does for matplotlib) because
        # this snippet's module-level imports do not include datetime.
        from datetime import datetime, timezone

        node = DataLineageNode(
            id=node_id,
            name=name,
            type=node_type,
            metadata=metadata,
            # Timezone-aware UTC so timestamps from different hosts compare.
            created_at=datetime.now(timezone.utc).isoformat(),
        )
        self.nodes[node_id] = node
        self.graph.add_node(node_id, **node.__dict__)

    def add_dataset(self, dataset_id, name, metadata):
        """Add (or replace) a dataset node."""
        self._add_node(dataset_id, name, "dataset", metadata)

    def add_transformation(self, transform_id, name, metadata):
        """Add (or replace) a transformation node."""
        self._add_node(transform_id, name, "transformation", metadata)

    def add_edge(self, source_id, target_id, transformation, parameters=None):
        """Record that *target_id* is derived from *source_id*.

        Endpoints that were never added become bare graph nodes (networkx
        creates them implicitly) and have no entry in ``self.nodes``.
        """
        edge = DataLineageEdge(
            source=source_id,
            target=target_id,
            transformation=transformation,
            parameters=parameters or {},
        )
        self.graph.add_edge(source_id, target_id, **edge.__dict__)

    def get_upstream(self, dataset_id):
        """Return ids of every node *dataset_id* transitively depends on."""
        return list(nx.ancestors(self.graph, dataset_id))

    def get_downstream(self, dataset_id):
        """Return ids of every node transitively derived from *dataset_id*."""
        return list(nx.descendants(self.graph, dataset_id))

    def visualize_lineage(self, dataset_id):
        """Plot *dataset_id* together with its upstream and downstream nodes."""
        import matplotlib.pyplot as plt

        related = self.get_upstream(dataset_id) + self.get_downstream(dataset_id)
        subgraph = self.graph.subgraph(related + [dataset_id])

        pos = nx.spring_layout(subgraph)
        nx.draw(subgraph, pos, with_labels=True, node_color='lightblue',
                node_size=2000, font_size=10, font_weight='bold')
        plt.title(f"Data Lineage for {dataset_id}")
        plt.show()
Mathematical Foundation
Data Drift Measurement
The Jensen-Shannon divergence for measuring data drift:
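$$
D_{JS}(P \parallel Q) = \frac{1}{2} D_{KL}(P \parallel M) + \frac{1}{2} D_{KL}(Q \parallel M)
$$
Where \( P \) and \( Q \) are the reference and current distributions of a feature, \( M = \frac{1}{2}(P + Q) \) is their mixture, and \( D_{KL}(P \parallel M) = \sum_i P(i) \log \frac{P(i)}{M(i)} \) is the Kullback-Leibler divergence. Unlike \( D_{KL} \), \( D_{JS} \) is symmetric and bounded (between 0 and 1 when using log base 2), which makes it convenient for drift thresholds.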
Data Quality Score
A composite score for data quality:
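$$
S_{\text{quality}} = w_1 s_1 + w_2 s_2 + w_3 s_3 + w_4 s_4 = \sum_{i=1}^{4} w_i s_i
$$
Where \( s_i \in [0, 1] \) is the score for the \( i \)-th quality dimension (for example completeness, validity, consistency, and uniqueness), and \( w_1 + w_2 + w_3 + w_4 = 1 \), with \( w_i \ge 0 \), are the weights for each dimension, so the composite score also lies in \( [0, 1] \).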
Version Similarity
Jaccard similarity between dataset versions:
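$$
J(A, B) = \frac{|A \cap B|}{|A \cup B|}
$$
Where \( A \) and \( B \) are sets of records from different versions; \( J = 1 \) means the versions contain exactly the same records and \( J = 0 \) means they share none.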
Data Quality Validation
Validation Framework
import pandas as pd
from typing import Dict, List, Any
import json
class DataValidator:
    """Run named, user-supplied checks against a DataFrame and summarize them.

    A rule is a callable that takes the DataFrame and returns a truthy value
    when the check passes.
    """

    def __init__(self):
        # rule name -> {"function": callable, "description": str}
        self.rules = {}
        # per-rule results of the most recent validate() call
        self.results = {}

    def add_rule(self, rule_name, rule_fn, description=""):
        """Register a rule, replacing any existing rule with the same name."""
        self.rules[rule_name] = {
            "function": rule_fn,
            "description": description
        }

    def validate(self, dataset: pd.DataFrame) -> Dict[str, Any]:
        """Run every registered rule against *dataset*.

        Returns:
            Mapping of rule name to ``{"passed": bool, "description": str}``;
            rules that raised also carry an ``"error"`` message. The mapping
            is also stored on ``self.results`` for ``generate_report()``.
        """
        results = {}
        for rule_name, rule_info in self.rules.items():
            try:
                # Coerce to a plain bool: pandas/NumPy checks return
                # numpy.bool_, which json cannot serialize. A rule that
                # returns a Series fails here (ambiguous truth value) and is
                # reported as an error instead of breaking generate_report().
                passed = bool(rule_info["function"](dataset))
            except Exception as e:
                # Deliberately broad: one crashing rule is recorded as a
                # failure rather than aborting the whole validation run.
                results[rule_name] = {
                    "passed": False,
                    "error": str(e),
                    "description": rule_info["description"]
                }
            else:
                results[rule_name] = {
                    "passed": passed,
                    "description": rule_info["description"]
                }
        self.results = results
        return results

    def generate_report(self):
        """Summarize the most recent validate() call.

        Returns:
            dict with rule counts, ``success_rate`` (0 when no rules were
            run) and the per-rule ``details``.
        """
        total_rules = len(self.results)
        passed_rules = sum(1 for r in self.results.values() if r["passed"])
        return {
            "total_rules": total_rules,
            "passed_rules": passed_rules,
            "failed_rules": total_rules - passed_rules,
            "success_rate": passed_rules / total_rules if total_rules > 0 else 0,
            "details": self.results
        }
# Example usage
validator = DataValidator()

# Add validation rules
validator.add_rule(
    "no_nulls",
    lambda df: df.isnull().sum().sum() == 0,
    "Dataset should have no null values",
)
validator.add_rule(
    "valid_dates",
    # errors="coerce" turns unparseable dates into NaT, so the rule returns
    # False instead of raising (which would be recorded as an error).
    lambda df: pd.to_datetime(df['date'], errors="coerce").notna().all(),
    "All dates should be valid",
)
validator.add_rule(
    "positive_values",
    lambda df: (df['amount'] > 0).all(),
    "All amounts should be positive",
)

# Small sample dataset to validate -- replace with your own data.
df = pd.DataFrame(
    {
        "date": ["2024-01-01", "2024-01-02", "2024-01-03"],
        "amount": [10.5, 3.0, 42.0],
    }
)

# Validate dataset
results = validator.validate(df)
report = validator.generate_report()
print(f"{report['passed_rules']}/{report['total_rules']} rules passed")
Best Practices
1. Immutable Data Storage
- Store raw data in append-only format
- Never modify original data files
- Use content-addressable storage
2. Metadata Documentation
- Document data sources
- Record transformation steps
- Track data quality metrics
3. Automated Validation
- Implement data quality checks
- Set up automated validation pipelines
- Create alerting for data issues
4. Access Control
- Implement role-based access
- Log data access patterns
- Maintain audit trails
Common Challenges
| Challenge | Description | Solution |
|---|---|---|
| Storage Costs | Large datasets are expensive to version | Incremental storage, compression |
| Performance | Versioning adds overhead | Async processing, caching |
| Complexity | Managing many versions | Clear naming conventions |
| Collaboration | Team coordination | Clear workflows, documentation |
Summary
Data versioning is essential for reproducible machine learning. By implementing proper version control with tools like DVC or Delta Lake, organizations can ensure data consistency, enable collaboration, and maintain compliance with data governance requirements.