Spark Performance Tuning: Shuffle, Broadcast, AQE
Optimizing Spark jobs for production workloads
Interview Question
"Your Spark job takes 2 hours to process 5TB of data. The job joins a 4TB fact table with a 100GB dimension table, then aggregates by 50 columns. Identify the top 3 performance bottlenecks and explain how to fix each one. Include specific Spark configurations and code changes."
Difficulty: Hard | Frequently asked at Databricks, Netflix, Uber, Airbnb
Theoretical Foundation
Spark Execution Model
┌─────────────────────────────────────────────────────────────┐
│               Spark Application Architecture                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Driver Program                                             │
│  ┌───────────────────────────────────────────────────────┐  │
│  │ SparkContext / SparkSession                           │  │
│  │  ┌─────────────────────────────────────────────────┐  │  │
│  │  │ DAG Scheduler                                   │  │  │
│  │  │ - Converts RDD lineage to DAG                   │  │  │
│  │  │ - Optimizes stages                              │  │  │
│  │  └─────────────────────────────────────────────────┘  │  │
│  │  ┌─────────────────────────────────────────────────┐  │  │
│  │  │ Task Scheduler                                  │  │  │
│  │  │ - Assigns tasks to executors                    │  │  │
│  │  │ - Handles data locality                         │  │  │
│  │  └─────────────────────────────────────────────────┘  │  │
│  └───────────────────────────────────────────────────────┘  │
│                                                             │
│  Executor 1          Executor 2          Executor 3         │
│  ┌───────────────┐   ┌───────────────┐   ┌───────────────┐  │
│  │ Task 1        │   │ Task 2        │   │ Task 3        │  │
│  │ Task 2        │   │ Task 3        │   │ Task 4        │  │
│  │ Cache         │   │ Cache         │   │ Cache         │  │
│  └───────────────┘   └───────────────┘   └───────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Shuffle: The #1 Performance Killer
A shuffle occurs when data must be redistributed across partitions, for example so that all rows with the same key end up in the same partition. It is usually the most expensive operation in a Spark job.
Shuffle operations:
- join() (when not broadcast)
- groupByKey()
- reduceByKey()
- distinct()
- repartition()
- sort()
What happens during a shuffle:
Shuffle costs:
- I/O: Write intermediate data to local disk
- Network: Transfer data across nodes
- Memory: Deserialize and buffer data
- GC: Memory pressure from large shuffle buffers
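The costs listed above follow from what a shuffle has to do: every record is routed to a partition chosen from its key, so records that started on different map tasks meet on the same reduce task. The sketch below imitates that routing in plain Python. It is only a model: `partition_for` and `shuffle` are illustrative names, and `zlib.crc32` stands in for Spark's own hash partitioner.

```python
import zlib
from collections import defaultdict


def partition_for(key: str, num_partitions: int) -> int:
    """Deterministic stand-in for a hash partitioner."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def shuffle(map_outputs, num_partitions):
    """Route (key, value) records from map-side partitions to reduce-side partitions."""
    reduce_inputs = defaultdict(list)
    for records in map_outputs:  # one list per map task
        for key, value in records:
            # On a real cluster this record is written to local disk by the
            # map task and later fetched over the network by the reduce task.
            reduce_inputs[partition_for(key, num_partitions)].append((key, value))
    return dict(reduce_inputs)


map_outputs = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("b", 4), ("c", 5), ("a", 6)],
]
reduced = shuffle(map_outputs, num_partitions=4)
for partition_id, records in sorted(reduced.items()):
    print(partition_id, records)
```

All records for a given key land in a single reduce partition, which is also why one very frequent key can overload one task.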
Shuffle metrics:
Broadcast Joins
A broadcast join sends the smaller table to all executors, avoiding shuffle.
When to broadcast:
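- One table is significantly smaller and fits in memory on the driver and on every executor (ideally well under 1GB; Spark cannot broadcast a table larger than 8GB)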
- Join key has high cardinality
- Shuffle is too expensive
Spark configuration:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1073741824") # 1GB
spark.conf.set("spark.sql.broadcastTimeout", "300") # 5 minutes
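To see why a broadcast join needs no shuffle, it helps to model it as a build phase and a probe phase. The following is a minimal pure-Python sketch under simplifying assumptions (rows are dicts, and `broadcast_hash_join` is an illustrative helper, not a Spark API): the small table becomes a lookup dictionary that every partition of the large table probes locally.

```python
def broadcast_hash_join(large_partitions, small_rows, key):
    """Inner-join each partition of the large table against a copied lookup table."""
    # Build phase: hash the small table once; conceptually this dictionary
    # is shipped to every executor.
    lookup = {}
    for row in small_rows:
        lookup.setdefault(row[key], []).append(row)

    # Probe phase: each large partition is joined where it already lives,
    # so no rows of the large table move between partitions.
    joined_partitions = []
    for partition in large_partitions:
        out = []
        for row in partition:
            for match in lookup.get(row[key], []):
                out.append({**match, **row})
        joined_partitions.append(out)
    return joined_partitions


fact = [
    [{"product_id": 1, "amount": 10}, {"product_id": 2, "amount": 5}],
    [{"product_id": 1, "amount": 7}, {"product_id": 3, "amount": 1}],
]
dim = [{"product_id": 1, "name": "widget"}, {"product_id": 2, "name": "gadget"}]
result = broadcast_hash_join(fact, dim, "product_id")
print(result)
```

The output keeps the partitioning of the large table, which is the property that makes the strategy cheap when the small side really is small.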
Adaptive Query Execution (AQE)
AQE dynamically optimizes queries at runtime based on actual data statistics.
┌─────────────────────────────────────────────────────────────┐
│               Adaptive Query Execution (AQE)                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Coalesce Shuffle Partitions                             │
│     - Merge small partitions after shuffle                  │
│     - Reduces task overhead                                 │
│                                                             │
│  2. Switch Join Strategies                                  │
│     - Convert sort-merge join to broadcast join             │
│     - Based on actual table sizes after filtering           │
│                                                             │
│  3. Optimize Skewed Joins                                   │
│     - Detect skewed partitions                              │
│     - Split large partitions                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘
AQE configurations:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
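The two skew settings work together: AQE treats a shuffle partition as skewed only when it is larger than the factor times the median partition size and also larger than the byte threshold. A small sketch of that rule, using made-up partition sizes (`find_skewed_partitions` is an illustrative helper, not part of Spark):

```python
from statistics import median

MB = 1024 * 1024


def find_skewed_partitions(sizes_bytes, factor=5, threshold_bytes=256 * MB):
    """Return indexes of partitions that exceed both the relative and absolute limits."""
    med = median(sizes_bytes)
    return [
        i
        for i, size in enumerate(sizes_bytes)
        if size > factor * med and size > threshold_bytes
    ]


# Hypothetical post-shuffle partition sizes
sizes = [100 * MB, 120 * MB, 90 * MB, 110 * MB, 2000 * MB, 400 * MB]
skewed = find_skewed_partitions(sizes)
print(skewed)
```

Here the median is 115MB, so the relative limit is 575MB. The 2000MB partition passes both tests, while the 400MB partition exceeds the 256MB threshold but not the relative limit and is left alone.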
Data Skew
Data skew occurs when some partitions have much more data than others.
Detection:
# Check partition sizes (row counts per partition) after shuffle
df.rdd.mapPartitionsWithIndex(lambda idx, it: [(idx, sum(1 for _ in it))]) \
    .collect()
Solutions:
- Salting (add random prefix to skew key)
- Isolate skewed keys
- Use AQE skew join optimization
- Pre-aggregate before join
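To make the effect of salting concrete, the sketch below compares the busiest partition with and without a salt for a dataset where one key holds 90% of the rows. It is a toy model: `partition_for` is a hypothetical partitioner for the composite key (key, salt), not Spark's actual hash function, and the data is synthetic.

```python
import random
import zlib
from collections import Counter


def partition_for(key, salt, num_partitions):
    """Toy partitioner for the composite key (key, salt)."""
    return (zlib.crc32(key.encode("utf-8")) + salt) % num_partitions


def max_partition_load(keys, num_partitions, n_salt=1, seed=42):
    """Row count of the busiest partition; n_salt=1 means no salting."""
    rng = random.Random(seed)
    load = Counter()
    for key in keys:
        salt = rng.randrange(n_salt)  # always 0 when n_salt == 1
        load[partition_for(key, salt, num_partitions)] += 1
    return max(load.values())


# 9,000 rows share one hot key; 1,000 rows have distinct keys
keys = ["hot"] * 9000 + [f"k{i}" for i in range(1000)]
unsalted_max = max_partition_load(keys, num_partitions=8)
salted_max = max_partition_load(keys, num_partitions=8, n_salt=8)
print(f"busiest partition without salt: {unsalted_max} rows")
print(f"busiest partition with salt:    {salted_max} rows")
```

Without a salt, every "hot" row lands in the same partition. With a salt, the hot key is spread over several partitions, at the cost of replicating the other side of the join once per salt value.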
Memory Management
┌─────────────────────────────────────────────────────────────┐
│                   Executor Memory Layout                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Total Executor Memory (e.g., 16GB)                         │
│  ┌───────────────────────────────────────────────────────┐  │
│  │ Reserved Memory (300MB)                               │  │
│  ├───────────────────────────────────────────────────────┤  │
│  │ User Memory (40% = 6.4GB)                             │  │
│  │ - User data structures                                │  │
│  │ - UDFs                                                │  │
│  ├───────────────────────────────────────────────────────┤  │
│  │ Spark Memory (60% = 9.6GB)                            │  │
│  │  ┌─────────────────────────────────────────────────┐  │  │
│  │  │ Execution Memory (50% = 4.8GB)                  │  │  │
│  │  │ - Shuffles, joins, sorts, aggregations          │  │  │
│  │  ├─────────────────────────────────────────────────┤  │  │
│  │  │ Storage Memory (50% = 4.8GB)                    │  │  │
│  │  │ - Cached data, broadcast variables              │  │  │
│  │  └─────────────────────────────────────────────────┘  │  │
│  └───────────────────────────────────────────────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Memory configurations:
from pyspark.sql import SparkSession
# Executor memory is fixed at launch: set these when building the session or
# via spark-submit --conf, not with spark.conf.set() on a running session
spark = SparkSession.builder \
    .config("spark.executor.memory", "16g") \
    .config("spark.memory.fraction", "0.6") \
    .config("spark.memory.storageFraction", "0.5") \
    .getOrCreate()
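The percentages in the layout diagram are rounded against the full 16GB. Assuming the unified memory model, where `spark.memory.fraction` and `spark.memory.storageFraction` apply to the heap after the 300MB reserve, the regions can be worked out as follows (`executor_memory_layout` is an illustrative helper, not a Spark function):

```python
def executor_memory_layout(heap_mb, memory_fraction=0.6, storage_fraction=0.5, reserved_mb=300):
    """Split an executor heap into the regions shown in the diagram (sizes in MB)."""
    usable = heap_mb - reserved_mb             # fractions apply after the reserve
    spark_memory = usable * memory_fraction    # execution + storage
    storage = spark_memory * storage_fraction  # share protected for cached data
    return {
        "reserved_mb": reserved_mb,
        "user_mb": usable - spark_memory,
        "spark_mb": spark_memory,
        "storage_mb": storage,
        "execution_mb": spark_memory - storage,
    }


layout = executor_memory_layout(16 * 1024)
for region, size in layout.items():
    print(f"{region}: {size:.0f} MB")
```

For a 16GB heap this gives roughly 9,650MB of Spark memory and 6,434MB of user memory, slightly less than the rounded figures in the diagram. The execution/storage split is a soft boundary: execution can borrow unused storage memory and the reverse.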
Code Implementation
Fix 1: Broadcast Join for Small Table
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder \
    .appName("SparkOptimization") \
    .getOrCreate()
# ============================================================
# BEFORE OPTIMIZATION: Shuffle Join (Slow)
# ============================================================
# Read tables
fact_sales = spark.read.parquet("s3://data-lake/fact_sales/") # 4TB
dim_product = spark.read.parquet("s3://data-lake/dim_product/") # 100GB
# This triggers a shuffle join (slow)
result = fact_sales.join(dim_product, "product_id")
# Spark plan shows:
# == Physical Plan ==
# SortMergeJoin [product_id], [product_id]
# :- Scan parquet fact_sales
# +- Scan parquet dim_product
# (Both tables are shuffled)
# ============================================================
# AFTER OPTIMIZATION: Broadcast Join (Fast)
# ============================================================
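# Option 1: Auto-broadcast (set threshold)
# dim_product is 100GB on disk, far above the 1GB threshold, so it is only
# broadcast once it is pruned to the columns the job needs and that pruned
# table is small enough. The column list below is illustrative.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1073741824") # 1GB
dim_small = dim_product.select("product_id", "category")
result_optimized = fact_sales.join(dim_small, "product_id")
# If dim_small is under the threshold, the Spark plan shows:
# == Physical Plan ==
# BroadcastHashJoin [product_id], [product_id], BuildRight
# :- Scan parquet fact_sales
# +- BroadcastExchange HashedRelationBroadcastMode
#    +- Scan parquet dim_product
# (dim_product is broadcast, no shuffle)
# Option 2: Explicit broadcast hint (forces a broadcast regardless of the
# threshold, so use it only when the table is known to fit in memory)
from pyspark.sql.functions import broadcast
result_hint = fact_sales.join(broadcast(dim_small), "product_id")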
# Option 3: Persist broadcast table
dim_product.persist()
dim_product.count() # Materialize in memory
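For the interview scenario it is worth checking the sizes before reaching for a broadcast: the dimension table is 100GB, while the threshold in the examples is 1GB. The sketch below is a deliberately simplified model of the size check behind auto-broadcast. `choose_join_strategy` is a hypothetical helper, and the 600MB figure for a pruned table is an assumed number for illustration only.

```python
GB = 1024 ** 3
MB = 1024 ** 2


def choose_join_strategy(estimated_bytes, threshold_bytes=1 * GB):
    """Simplified model: broadcast only when the size estimate is within the threshold."""
    return "broadcast" if estimated_bytes <= threshold_bytes else "sort-merge"


full_dim = 100 * GB    # dim_product as stored
pruned_dim = 600 * MB  # assumed size after keeping only the needed columns

print(choose_join_strategy(full_dim))    # sort-merge
print(choose_join_strategy(pruned_dim))  # broadcast
```

The real planner works from table statistics and hints, so treat this as a way to reason about the threshold rather than a description of Spark's internals. The practical point is that the 100GB table has to shrink dramatically, or the join has to stay a shuffle join tuned by the later fixes.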
Fix 2: Optimize Shuffle Partitions
# ============================================================
# OPTIMIZE SHUFFLE PARTITIONS
# ============================================================
# Default: 200 partitions (too few for 4TB: about 20GB per partition)
spark.conf.set("spark.sql.shuffle.partitions", "200")
# Rule of thumb: 128-200MB per partition
# 4TB / 200MB = ~21,000 partitions
# 4TB / 128MB = ~32,000 partitions
# Set optimal partitions
spark.conf.set("spark.sql.shuffle.partitions", "32000")
# Or use AQE to auto-optimize
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionSize", "64MB")
spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "32000")
# ============================================================
# AQE OPTIMIZATION
# ============================================================
# With AQE enabled, Spark will:
# 1. Start with 32000 partitions
# 2. After shuffle, coalesce small partitions
# 3. Final partitions might be 10000 (if many are small)
# Enable AQE skew join optimization
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
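The partition arithmetic and the coalescing step can both be checked with a few lines of plain Python. This is a rough sketch: `shuffle_partitions_for` just divides data volume by a target partition size, and `coalesce_adjacent` is a simplified greedy model of how small neighbouring partitions get merged, not Spark's exact algorithm.

```python
import math

MB = 1024 ** 2
TB = 1024 ** 4


def shuffle_partitions_for(total_bytes, target_partition_bytes):
    """Number of partitions needed so each holds about the target size."""
    return math.ceil(total_bytes / target_partition_bytes)


def coalesce_adjacent(sizes, target_bytes):
    """Greedily merge neighbouring partitions until adding one more would exceed the target."""
    merged, current = [], 0
    for size in sizes:
        if current and current + size > target_bytes:
            merged.append(current)
            current = 0
        current += size
    if current:
        merged.append(current)
    return merged


print(shuffle_partitions_for(4 * TB, 128 * MB))  # 32768
print(shuffle_partitions_for(4 * TB, 200 * MB))  # 20972

sizes = [10 * MB, 20 * MB, 30 * MB, 100 * MB, 5 * MB, 5 * MB]
print([s // MB for s in coalesce_adjacent(sizes, 64 * MB)])  # [60, 100, 10]
```

Starting high and letting AQE merge small partitions is usually safer than guessing a low number, because an oversized partition cannot be repaired by coalescing.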
Fix 3: Handle Data Skew
# ============================================================
# HANDLE DATA SKEW WITH SALT
# ============================================================
# Problem: Join on "category" column is skewed
# Some categories have 100x more data
# Solution 1: Salt the skewed key
def salt_join(df_large, df_small, join_key, n_salt=10):
    """
    Salt a join to handle skew.
    Adds a random salt to the large table and replicates the small table
    once per salt value.
    """
    # Add a random salt in [0, n_salt) to each row of the large table
    df_large_salted = df_large.withColumn(
        "salt", (F.rand() * n_salt).cast("int")
    )
    # Replicate small table for each salt value
    salts = spark.range(0, n_salt).select(F.col("id").cast("int").alias("salt"))
    df_small_salted = df_small.crossJoin(salts)
    # Join on the original key plus the salt, then drop the helper column.
    # Joining on column names keeps a single copy of each join column.
    return df_large_salted.join(df_small_salted, [join_key, "salt"]).drop("salt")
# Solution 2: Isolate skewed keys
def isolate_skew_join(df_large, df_small, join_key, skew_threshold=1000000):
    """
    Handle skew by isolating skewed keys.
    """
    # Find skewed keys
    skewed_keys = df_large.groupBy(join_key) \
        .count() \
        .filter(F.col("count") > skew_threshold) \
        .select(join_key)
    # Split data
    df_large_skewed = df_large.join(skewed_keys, join_key)
    df_large_normal = df_large.join(skewed_keys, join_key, "left_anti")
    # Join normal data (no skew)
    result_normal = df_large_normal.join(df_small, join_key)
    # Join skewed data with salting
    result_skewed = salt_join(df_large_skewed, df_small, join_key)
    # Union results
    return result_normal.unionByName(result_skewed, allowMissingColumns=True)
# Solution 3: Pre-aggregate before join
def preaggregate_join(df_large, df_small, join_key, agg_col):
    """
    Pre-aggregate large table before joining.
    """
    # Pre-aggregate large table
    df_large_agg = df_large.groupBy(join_key) \
        .agg(F.sum(agg_col).alias(f"sum_{agg_col}"))
    # Join with small table
    result = df_large_agg.join(df_small, join_key)
    # Explode if needed (if join is many-to-many)
    return result
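A reasonable worry with salting is whether it changes the join result. It should not: each row of the large table carries exactly one salt, and the small table is replicated once per salt value, so every large row still finds exactly the matches it would have found in a plain join. The pure-Python check below illustrates this on toy data (`plain_join`, `salted_join`, and `canon` are illustrative helpers).

```python
import random


def plain_join(large, small, key):
    """Reference inner join."""
    return [{**s, **l} for l in large for s in small if l[key] == s[key]]


def salted_join(large, small, key, n_salt=4, seed=0):
    """Inner join on (key, salt): random salt on the large side, replicated small side."""
    rng = random.Random(seed)
    large_salted = [{**row, "salt": rng.randrange(n_salt)} for row in large]
    small_salted = [{**row, "salt": s} for row in small for s in range(n_salt)]
    joined = []
    for l in large_salted:
        for s in small_salted:
            if l[key] == s[key] and l["salt"] == s["salt"]:
                row = {**s, **l}
                del row["salt"]  # helper column is not part of the result
                joined.append(row)
    return joined


def canon(rows):
    """Order-independent representation for comparing join outputs."""
    return sorted(sorted(r.items()) for r in rows)


large = [{"category": "hot", "amount": i} for i in range(6)]
large.append({"category": "cold", "amount": 99})
small = [{"category": "hot", "label": "H"}, {"category": "cold", "label": "C"}]

same = canon(plain_join(large, small, "category")) == canon(salted_join(large, small, "category"))
print(same)
```

The price of this equivalence is the replication: the small side grows by a factor of `n_salt`, so the salt count should stay as low as the skew allows.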
Fix 4: Optimize Serialization
# ============================================================
# OPTIMIZE SERIALIZATION
# ============================================================
# Use Kryo serialization (faster than Java serialization)
# Kryo mainly helps RDD code and custom objects; DataFrame operations
# already use Spark's internal binary format.
# These are launch-time settings: pass them when the session is created
# (or via spark-submit --conf); spark.conf.set() cannot change them later.
spark = SparkSession.builder \
    .appName("SparkOptimization") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.kryoserializer.buffer.max", "512m") \
    .config("spark.kryo.registrationRequired", "false") \
    .config("spark.memory.offHeap.enabled", "true") \
    .config("spark.memory.offHeap.size", "4g") \
    .getOrCreate()
Fix 5: Cache Strategically
# ============================================================
# STRATEGIC CACHING
# ============================================================
# Cache frequently accessed data
dim_product = spark.read.parquet("s3://data-lake/dim_product/")
dim_product.cache() # or .persist(StorageLevel.MEMORY_AND_DISK)
# Materialize cache
dim_product.count()
# Use cache for multiple operations
result1 = fact_sales.groupBy("product_id").agg(F.sum("amount"))
result2 = fact_sales.filter(F.col("amount") > 100).join(dim_product, "product_id")
# Check cache status (the cached size is shown in the Storage tab of the Spark UI)
print(f"StorageLevel: {dim_product.storageLevel}")
print(f"Is cached: {dim_product.is_cached}")
# Unpersist when done
dim_product.unpersist()
Monitoring and Profiling
# ============================================================
# MONITORING AND PROFILING
# ============================================================
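# Enable Spark UI and event logging
# (launch-time settings: pass them when building the session or via spark-submit --conf)
spark = SparkSession.builder \
    .config("spark.ui.enabled", "true") \
    .config("spark.eventLog.enabled", "true") \
    .config("spark.eventLog.dir", "s3://spark-logs/event-logs/") \
    .getOrCreate()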
# Monitor shuffle metrics
def monitor_shuffle(df):
    """Monitor shuffle metrics"""
    # Count shuffle partitions
    rdd = df.rdd
    num_partitions = rdd.getNumPartitions()
    print(f"Number of partitions: {num_partitions}")
    # Get partition sizes (row counts; this runs a full job over the data)
    partition_sizes = rdd.mapPartitionsWithIndex(
        lambda idx, it: [(idx, sum(1 for _ in it))]
    ).collect()
    print(f"Partition sizes: {partition_sizes}")
    # Check for skew
    sizes = [s for _, s in partition_sizes]
    avg_size = sum(sizes) / len(sizes)
    max_size = max(sizes)
    # Guard against an empty DataFrame, where the average is 0
    skew_ratio = max_size / avg_size if avg_size else 0.0
    print(f"Average partition size: {avg_size}")
    print(f"Max partition size: {max_size}")
    print(f"Skew ratio: {skew_ratio}")
    if skew_ratio > 10:
        print("WARNING: High skew detected!")
# Profile a query
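def profile_query(query):
    """Profile a Spark query"""
    # Print the parsed, analyzed, optimized, and physical plans
    query.explain(mode="extended")
    # Count Exchange nodes (shuffles and broadcasts) in the physical plan.
    # _jdf is an internal handle, so this may change between Spark versions.
    plan = query._jdf.queryExecution().executedPlan().toString()
    print(f"Exchange nodes in plan: {plan.count('Exchange')}")
    # Runtime metrics (shuffle bytes, spill, task time) exist only after an
    # action has run; read them from the SQL tab of the Spark UI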
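When only the text of a physical plan is available, for example copied from `explain()` output or the Spark UI, a few lines of string handling are enough to count the expensive operators. The sketch below is an assumption-laden helper: `count_plan_nodes` is hypothetical, and `sample_plan` is a hand-written plan in the usual layout, not output captured from a real job.

```python
def count_plan_nodes(plan_text):
    """Count shuffle exchanges, broadcast exchanges, and sort-merge joins in plan text."""
    counts = {"shuffle_exchanges": 0, "broadcast_exchanges": 0, "sort_merge_joins": 0}
    for line in plan_text.splitlines():
        # Drop tree-drawing characters and codegen stage markers such as "*(2) "
        node = line.lstrip(" :+-*()0123456789")
        if node.startswith("BroadcastExchange"):
            counts["broadcast_exchanges"] += 1
        elif node.startswith("Exchange"):
            counts["shuffle_exchanges"] += 1
        elif node.startswith("SortMergeJoin"):
            counts["sort_merge_joins"] += 1
    return counts


sample_plan = """
*(5) SortMergeJoin [product_id#1], [product_id#7], Inner
:- *(2) Sort [product_id#1 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(product_id#1, 200)
:     +- *(1) FileScan parquet [product_id#1,amount#2]
+- *(4) Sort [product_id#7 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(product_id#7, 200)
      +- *(3) FileScan parquet [product_id#7,category#8]
"""
counts = count_plan_nodes(sample_plan)
print(counts)
```

Two shuffle exchanges feeding a sort-merge join is the signature of the slow plan from Fix 1; after a successful broadcast the same check would show a broadcast exchange and no shuffle for the join.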
Complete Optimization Example
# ============================================================
# COMPLETE OPTIMIZATION EXAMPLE
# ============================================================
def optimize_spark_job():
    """Optimize the original slow job"""
    # 1. Enable AQE
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
    # 2. Optimize shuffle partitions
    spark.conf.set("spark.sql.shuffle.partitions", "32000")
    # 3. Broadcast small table
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1073741824")
    # 4. Use Kryo serialization
    #    spark.serializer is a launch-time setting: pass it to
    #    SparkSession.builder.config(...) or spark-submit --conf, not here
    # 5. Read data
    #    No repartition() before the join: a broadcast join leaves the fact
    #    table where it is, and repartitioning 4TB would add a full shuffle
    fact_sales = spark.read.parquet("s3://data-lake/fact_sales/")
    dim_product = spark.read.parquet("s3://data-lake/dim_product/")
    # 6. Broadcast join (picked by the planner or AQE only when dim_product,
    #    after column pruning and filters, is under the threshold above)
    joined = fact_sales.join(dim_product, "product_id")
    # 7. Pre-aggregate if possible
    #    Instead of: joined.groupBy("col1", "col2", "col3", ...).agg(...)
    #    Use: joined.groupBy("product_id").agg(...).groupBy("col1", ...).agg(...)
    # 8. Cache intermediate results only if reused
    #    joined feeds a single aggregation here, so caching it would only cost memory
    # 9. Execute with monitoring
    result = joined \
        .groupBy("category", "region", "year", "quarter") \
        .agg(
            F.sum("amount").alias("total_amount"),
            F.count("*").alias("transaction_count")
        )
    # Write result
    result.write.format("delta") \
        .mode("overwrite") \
        .save("s3://data-warehouse/sales_summary/")

optimize_spark_job()
Production Tip: Always check the Spark UI for shuffle metrics. If you see high shuffle read/write, focus on: (1) broadcasting small tables, (2) right-sizing shuffle partitions, (3) handling skew. The Spark UI is your best friend for performance tuning.
Common Follow-Up Questions
Q1: When should you use repartition() vs coalesce()?
- repartition(n): Full shuffle, creates n partitions. Use when increasing partitions or when you need an even distribution.
- coalesce(n): No shuffle, merges existing partitions. Use when decreasing partitions.
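The difference is easiest to see on uneven partitions. In the simplified model below, `coalesce` glues whole partitions together without looking at individual records, while `repartition` redistributes every record. Both functions are plain-Python illustrations, not Spark's implementations (Spark's coalesce also considers data locality when grouping partitions).

```python
def coalesce(partitions, n):
    """Merge whole partitions into n groups; no record leaves its group of parents."""
    merged = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        merged[i * n // len(partitions)].extend(part)
    return merged


def repartition(partitions, n):
    """Redistribute every record round-robin across n partitions (a full shuffle)."""
    out = [[] for _ in range(n)]
    records = [r for part in partitions for r in part]
    for i, record in enumerate(records):
        out[i % n].append(record)
    return out


# Four uneven partitions with 10, 1, 1, and 2 records
parts = [list(range(0, 10)), [10], [11], [12, 13]]

coalesced_sizes = [len(p) for p in coalesce(parts, 2)]
repartitioned_sizes = [len(p) for p in repartition(parts, 2)]
print(coalesced_sizes)      # [11, 3]
print(repartitioned_sizes)  # [7, 7]
```

Coalescing is cheap but inherits whatever imbalance the input had; repartitioning pays for a shuffle and gets an even spread in return.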
Q2: How do you debug a slow Spark job?
- Check Spark UI for shuffle metrics
- Look for data skew in partition sizes
- Check if broadcast join is possible
- Verify serialization is efficient
- Check memory usage and GC pauses
Q3: What's the difference between persist() and cache()?
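- cache(): Same as persist(StorageLevel.MEMORY_AND_DISK) for DataFrames (RDD cache() uses MEMORY_ONLY)
- persist(level): More control over storage level
In Scala or Java, use persist(StorageLevel.MEMORY_AND_DISK_SER) for large datasets. PySpark always stores data in serialized form, so MEMORY_AND_DISK fills the same role there.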
Q4: How do you optimize Spark for streaming?
# Streaming optimizations
spark.conf.set("spark.sql.streaming.checkpointLocation", "s3://checkpoints/")
spark.conf.set("spark.sql.streaming.stateStore.providerClass",
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
spark.conf.set("spark.sql.streaming.stateStore.minPartitions", "32")
Critical Consideration: Spark performance tuning is iterative. Always: (1) measure before optimizing, (2) make one change at a time, (3) verify improvement. Don't optimize blindly; let the data guide you.
Company-Specific Tips
Databricks Interview Tips
- Discuss Photon engine and vectorized execution
- Mention Delta Lake optimizations (Z-Order, Optimize)
- Be ready to explain AQE in detail
- Talk about Auto Optimize and Auto Compaction
Netflix Interview Tips
- Focus on large-scale data processing (petabytes)
- Discuss cost optimization with spot instances
- Mention multi-tenant Spark clusters
- Talk about streaming + batch unification
Uber Interview Tips
- Discuss real-time analytics with Spark Streaming
- Explain geospatial data processing optimizations
- Mention ML pipelines with Spark MLlib
- Talk about resource management with Kubernetes
Final Takeaway: Spark optimization is about understanding the tradeoffs between memory, compute, and I/O. The biggest wins come from: (1) avoiding shuffle, (2) broadcasting small tables, (3) handling skew, and (4) using AQE. Always profile before optimizing.