Partitioning Strategies
Difficulty: Senior Level | Companies: Databricks, Netflix, Uber, Apple, Airbnb
Why Partitioning Matters
Partitioning determines how data is distributed across executors. Poor partitioning leads to data skew, excessive shuffle, and underutilized resources.
Reading Partitioned Data
from pyspark.sql import SparkSession, functions as F
# Build (or reuse) the session. Both parallelism settings are pinned to 200:
# spark.sql.shuffle.partitions applies to DataFrame/SQL shuffles and
# spark.default.parallelism to RDD operations.
spark = (
    SparkSession.builder
    .appName("PartitioningStrategies")
    .config("spark.sql.shuffle.partitions", "200")
    .config("spark.default.parallelism", "200")
    .getOrCreate()
)
# Reading Hive-style partitioned data (directories such as year=2024/month=6/).
# DataFrameReader has no partitionBy(): Spark discovers the partition columns
# from the directory names and exposes them as ordinary columns.
df = spark.read.parquet("hdfs://data/events/")

# Filtering on partition columns lets Spark skip non-matching directories
# entirely (partition pruning); look for PartitionFilters in the scan node.
filtered = df.filter((F.col("year") == 2024) & (F.col("month") == 6))
filtered.explain()

# getNumPartitions() reports Spark input splits (i.e. scan tasks), not the
# number of Hive-style partition directories that were read.
print(f"Input splits after pruning: {filtered.rdd.getNumPartitions()}")

# Check partition count for the unfiltered dataset
print(f"Input splits for full dataset: {df.rdd.getNumPartitions()}")
ℹ️
Interview Insight: Partition pruning is one of Spark's most powerful optimizations. Partition large datasets by commonly filtered, low-to-medium-cardinality columns (date, region, etc.) so that queries filtering on them can skip irrelevant partitions entirely.
Hash Partitioning
# Hash partitioning: every row is assigned to a partition by hashing user_id,
# so all rows sharing a user_id land in the same one of the 200 partitions.
df = spark.read.parquet("hdfs://data/transactions")
hash_partitioned = df.repartition(200, F.col("user_id"))
# Verify partition distribution
def check_partition_distribution(df, partition_col):
    """Summarise how the rows of ``df`` are spread over its current partitions.

    Args:
        df: DataFrame whose physical partitioning should be inspected.
        partition_col: Column the data was partitioned by. Its number of
            distinct values per partition is reported alongside the row count.

    Returns:
        DataFrame with one row per non-empty partition (``partition_id``,
        ``row_count``, ``distinct_keys``), ordered by ``partition_id``.
    """
    return (
        df.withColumn("partition_id", F.spark_partition_id())
        .groupBy("partition_id")
        .agg(
            F.count("*").alias("row_count"),
            # Key cardinality per partition shows whether a large partition
            # holds many keys or just a few hot ones.
            F.countDistinct(partition_col).alias("distinct_keys"),
        )
        .orderBy("partition_id")
    )


check_partition_distribution(hash_partitioned, "user_id").show(20)
# For RDDs, use a custom partition function
from pyspark.rdd import portable_hash

# partitionBy() operates on pair RDDs, so build explicit (key, value) tuples.
rdd = df.select("user_id", "amount").rdd.map(
    lambda row: (row["user_id"], row["amount"])
)


def user_partitioner(key, num_partitions=200):
    """Map ``key`` to a partition index in ``[0, num_partitions)``.

    Uses PySpark's ``portable_hash`` instead of the built-in ``hash``: on
    older CPython the built-in hash of ``None`` (and of tuples containing it)
    differs between Python worker processes, so the same key could be sent to
    different partitions on different executors. ``portable_hash`` also
    refuses to run unless PYTHONHASHSEED is fixed, which keeps string hashes
    consistent.
    """
    return portable_hash(key) % num_partitions


partitioned_rdd = rdd.partitionBy(200, user_partitioner)
Range Partitioning
# Range partitioning samples the key column to choose split points, then
# assigns each row to the partition whose key range contains it. Because the
# bounds come from a sample, partition sizes are only approximately equal and
# may vary between runs; rows are NOT sorted within a partition.
df = spark.read.parquet("hdfs://data/transactions")
range_partitioned = df.repartitionByRange(200, "amount")

# Verify range distribution: each partition should cover a contiguous slice
# of "amount" values.
(
    range_partitioned
    .withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .agg(
        F.min("amount").alias("min_amount"),
        F.max("amount").alias("max_amount"),
        F.count("*").alias("count"),
    )
    .orderBy("partition_id")
    .show(20)
)

# Multiple columns: ranges are defined over (region, amount) in that order.
range_by_region_amount = df.repartitionByRange(200, "region", "amount")
Coalesce vs Repartition
# Coalesce: merges existing partitions without a full shuffle.
df = spark.read.parquet("hdfs://data/large-dataset")
print(f"Original partitions: {df.rdd.getNumPartitions()}")

# Coalesce down: existing partitions are combined, so sizes can be uneven.
coalesced = df.coalesce(50)
print(f"After coalesce: {coalesced.rdd.getNumPartitions()}")

# Repartition: full shuffle; can increase or decrease the count and produces
# roughly equal-sized partitions.
repartitioned = df.repartition(50)
print(f"After repartition: {repartitioned.rdd.getNumPartitions()}")

# Both are linear in the data size; the difference is the cost profile:
# - coalesce avoids shuffle I/O, but since it adds no stage boundary it also
#   reduces the parallelism of the computation before it (a drastic
#   coalesce such as coalesce(1) runs the upstream work in a single task).
# - repartition pays for a shuffle but keeps upstream parallelism and evens
#   out unevenly sized partitions.

# Repartition by column: hash shuffle that co-locates equal keys (partition
# sizes still follow the key distribution, so skewed keys stay skewed).
df_by_region = df.repartition(50, "region")

# Coalesce cannot increase partitions: asking for more than the current
# count leaves the count unchanged (the original count of df here).
coalesced_up = df.coalesce(500)
print(f"coalesce(500): {coalesced_up.rdd.getNumPartitions()}")

# Repartition can increase partitions
repartitioned_up = df.repartition(500)
print(f"repartition(500): {repartitioned_up.rdd.getNumPartitions()}")  # 500
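⚠️
Warning: Use coalesce to reduce the partition count cheaply (no shuffle). Coalesce merges existing partitions, so the resulting sizes can be uneven. A drastic reduction (e.g. coalesce(1)) also limits the parallelism of the computation that precedes it. Use repartition when increasing partitions, when you need evenly sized partitions, or when you need data redistributed by key.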
Custom Partitioners
class RegionPartitioner:
    """Partition function that routes pair-RDD records by region.

    PySpark's ``RDD.partitionBy(numPartitions, partitionFunc)`` takes a plain
    callable: it calls ``partitionFunc(key)`` and reduces the result modulo
    ``numPartitions``. Unlike Scala, ``pyspark`` exposes no ``Partitioner``
    base class to subclass, and ``getPartition`` is never invoked by Spark,
    so this class is made callable via ``__call__``.

    Args:
        numPartitions: Number of output partitions.
        region_map: Mapping from region name to partition index. Regions
            missing from the map are sent to partition 0.
    """

    def __init__(self, numPartitions, region_map):
        self.numPartitions = numPartitions
        self.region_map = region_map

    def getPartition(self, key):
        """Return the partition index for ``key``, expected to be ``(region, ...)``."""
        region = key[0]
        # Unknown regions share partition 0 with whichever region maps to 0,
        # which can skew that partition if many keys are unmapped.
        return self.region_map.get(region, 0) % self.numPartitions

    def __call__(self, key):
        # Entry point used by RDD.partitionBy.
        return self.getPartition(key)
# Create region mapping: one partition per known region.
regions = ["US", "EU", "APAC", "LATAM"]
region_map = {r: i for i, r in enumerate(regions)}
num_region_partitions = len(regions)

# Apply custom partitioner. partitionBy() works on pair RDDs, so each record
# becomes ((region, amount), 1); the partitioner reads key[0] (the region).
rdd = (
    spark.read.parquet("hdfs://data/sales")
    .select("region", "amount")
    .rdd
    .map(lambda row: ((row[0], row[1]), 1))
)
custom_partitioned = rdd.partitionBy(
    num_region_partitions,
    RegionPartitioner(num_region_partitions, region_map),
)

# Verify distribution: list of (partition_index, record_count).
distribution = custom_partitioned.mapPartitionsWithIndex(
    lambda idx, it: [(idx, sum(1 for _ in it))]
).collect()
print(distribution)
Partitioning for Joins
# Co-partitioning: both sides hash-partitioned by the join key with the same
# partition count, so the join can reuse that layout.
left = spark.read.parquet("hdfs://data/orders")
right = spark.read.parquet("hdfs://data/customers")

# Repartition both sides by join key. These repartitions ARE shuffles; what
# they save is the join's own exchange. That only pays off when the
# partitioned DataFrames are reused (e.g. cached and joined several times).
left_partitioned = left.repartition(200, "customer_id")
right_partitioned = right.repartition(200, "customer_id")

result = left_partitioned.join(right_partitioned, "customer_id")
# Expect the Exchange nodes under the two repartitions and typically no
# additional Exchange inserted for the join itself.
result.explain()

# Broadcast join: `right` is shipped whole to every executor (it must be small
# enough to fit in executor memory), so `left` is not shuffled at all.
from pyspark.sql.functions import broadcast
result = left.join(broadcast(right), "customer_id")
Partition Pruning Optimization
# Partition your data by frequently filtered columns
df = spark.read.parquet("hdfs://data/events")

# Write with partitioning: one directory per event_date/event_type value pair.
(
    df.write
    .partitionBy("event_date", "event_type")
    .mode("overwrite")
    .parquet("hdfs://data/events_partitioned/")
)

# Queries filtering on the partition columns read only matching directories.
query = (
    spark.read.parquet("hdfs://data/events_partitioned/")
    .filter(
        (F.col("event_date") == "2024-06-15") & (F.col("event_type") == "click")
    )
)
query.explain()  # Check for PartitionFilters in the FileScan node

# Partition statistics: rows per partition directory. The data was written to
# a path, not registered as a table, so there is no "events_partitioned"
# table for DESCRIBE EXTENDED to inspect.
(
    spark.read.parquet("hdfs://data/events_partitioned/")
    .groupBy("event_date", "event_type")
    .count()
    .orderBy("event_date", "event_type")
    .show(50)
)
ℹ️
Pro Tip: Avoid partitioning by high-cardinality columns (like user_id). This creates too many small files. Instead, partition by low-to-medium cardinality columns (date, region, category).
Dynamic Partition Pruning (Spark 3.x)
# Dynamic partition pruning (DPP, Spark 3.0+, controlled by
# spark.sql.optimizer.dynamicPartitionPruning.enabled): when a table that is
# partitioned on disk is joined on its partition column and the OTHER side
# carries a selective filter, Spark works out at runtime which partitions can
# match and skips the rest. With default settings this is typically applied
# when the filtered side is broadcast.
#
# DataFrames have no partitionBy(); partitioning is a property of how the
# files were written (e.g. facts/date=2024-06-15/part-*.parquet).
facts = spark.read.parquet("hdfs://data/facts")  # assumed written with partitionBy("date")
dimensions = spark.read.parquet("hdfs://data/dimensions")  # small table with a "date" column

# Filter on a dimension attribute (assumed column "region" -- adjust to your
# schema), NOT on facts.date: the matching dates are only known at runtime,
# which is exactly the case DPP handles.
selected_dims = dimensions.filter(F.col("region") == "EU")

result = facts.join(selected_dims, "date")

# Look for dynamicpruningexpression(...) among the PartitionFilters of the
# facts scan.
result.explain(mode="formatted")
Monitoring Partition Health
# Monitor partition sizes and distribution
df = spark.read.parquet("hdfs://data/events")

partition_stats = (
    df.withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .agg(
        F.count("*").alias("row_count"),
        F.count("amount").alias("non_null_amount"),  # count(col) skips nulls
        F.countDistinct("user_id").alias("unique_users"),
    )
    .orderBy("partition_id")
    # Three actions below use these stats; caching avoids re-scanning the
    # whole dataset for each of them.
    .cache()
)
partition_stats.show(50)

# Check for skewed partitions (absolute row threshold; tune for your data)
SKEW_ROW_THRESHOLD = 1_000_000
skewed = partition_stats.filter(F.col("row_count") > SKEW_ROW_THRESHOLD)
if skewed.head(1):  # cheaper than count() > 0: stops at the first match
    print("WARNING: Skewed partitions detected!")
    skewed.show()

partition_stats.unpersist()
# Get total data size per partition using Hadoop FileSystem
def get_partition_sizes(path):
    """Return ``(name, total_bytes)`` for each top-level entry under ``path``.

    For Hive-style layouts the entries are partition directories
    (e.g. ``year=2024``). ``FileStatus.getLen()`` is 0 for a directory, so
    the size comes from ``getContentSummary``, which sums all files beneath
    the entry. Hidden/metadata entries (names starting with ``_`` or ``.``,
    such as ``_SUCCESS``) are skipped, matching what Spark itself ignores.
    """
    sc = spark.sparkContext
    jvm = sc._jvm
    hadoop_conf = sc._jsc.hadoopConfiguration()
    root = jvm.org.apache.hadoop.fs.Path(path)
    # Resolve the filesystem from the path itself rather than fs.defaultFS,
    # so paths on a different scheme/cluster are listed correctly.
    fs = root.getFileSystem(hadoop_conf)

    sizes = []
    for status in fs.listStatus(root):
        entry_path = status.getPath()
        name = entry_path.getName()
        if name.startswith(("_", ".")):
            continue
        sizes.append((name, fs.getContentSummary(entry_path).getLength()))
    return sizes
# Summarise the top-level partition entries and their combined size in GiB.
sizes = get_partition_sizes("hdfs://data/events/")
total_bytes = sum(size for _, size in sizes)
print(f"Total partitions: {len(sizes)}")
print(f"Total size: {total_bytes / 1024**3:.2f} GB")
ℹ️
Key Takeaway: Choose partitioning based on your query patterns. Use hash partitioning for even distribution by key, range partitioning for ordered data, and partition large datasets on disk by commonly filtered, low-to-medium-cardinality columns.
Follow-Up Questions
- How does Spark decide between sort-merge join and broadcast join based on partitioning?
- Explain the difference between `spark.sql.shuffle.partitions` and `spark.default.parallelism`.
- When would you use `repartitionByRange` over `repartition`?
- How does partitioning interact with Delta Lake's Z-ordering?
- Describe strategies for repartitioning a live table without downtime.