Join Strategies: Sort-Merge, Broadcast, Shuffle Hash, Broadcast Nested Loop
Difficulty: Expert | Companies: Meta, Google, Netflix, Uber, Apple
ℹ️ Interview Context
Join strategy selection is fundamental to Spark performance. Interviewers expect understanding of when each strategy is used, their mathematical complexity, and how to influence Spark's choice.
Question
Compare all Spark join strategies: Sort-Merge, Broadcast Hash, Shuffle Hash, and Broadcast Nested Loop. What determines which strategy Spark chooses? How do you handle joins with skewed data? Provide mathematical complexity analysis for each strategy.
Detailed Answer
1. Join Strategy Selection Flow
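The selection order can be summarized as a simplified decision function. This is a hypothetical model (`choose_join_strategy` is not a Spark API) of Spark 3.x's choice for joins without hints; explicit hints are checked before any of this. Broadcast is tried first, then shuffle hash (only when `spark.sql.join.preferSortMergeJoin` is false, the build side fits per-partition hash maps, and it is at least 3× smaller), then sort-merge; non-equi joins go to broadcast nested loop or a Cartesian product. Sizes are estimated bytes, and the sketch ignores join-type restrictions on which side may be broadcast and other edge cases.

```python
def choose_join_strategy(
    left_bytes: int,
    right_bytes: int,
    has_equi_keys: bool,
    inner_join: bool = True,
    broadcast_threshold: int = 10 * 1024 * 1024,  # spark.sql.autoBroadcastJoinThreshold
    shuffle_partitions: int = 200,                 # spark.sql.shuffle.partitions
    prefer_sort_merge: bool = True,                # spark.sql.join.preferSortMergeJoin
) -> str:
    """Hypothetical, simplified model of Spark's join selection order (no hints)."""
    small, large = sorted((left_bytes, right_bytes))
    # A threshold of -1 disables automatic broadcasting.
    can_broadcast = 0 <= small <= broadcast_threshold

    if has_equi_keys:
        if can_broadcast:
            return "BroadcastHashJoin"
        # Build side must fit per-partition hash maps and be at least 3x smaller.
        fits_local_maps = small < broadcast_threshold * shuffle_partitions
        much_smaller = small * 3 <= large
        if not prefer_sort_merge and fits_local_maps and much_smaller:
            return "ShuffledHashJoin"
        return "SortMergeJoin"

    # Non-equi condition: only nested-loop style joins apply.
    if can_broadcast:
        return "BroadcastNestedLoopJoin"
    if inner_join:
        return "CartesianProduct"
    return "BroadcastNestedLoopJoin"  # last resort: may be very slow or OOM


MB = 1024 * 1024
print(choose_join_strategy(10_000 * MB, 5 * MB, has_equi_keys=True))      # BroadcastHashJoin
print(choose_join_strategy(10_000 * MB, 1_000 * MB, has_equi_keys=True))  # SortMergeJoin
print(choose_join_strategy(10_000 * MB, 1_000 * MB, has_equi_keys=True,
                           prefer_sort_merge=False))                       # ShuffledHashJoin
print(choose_join_strategy(10_000 * MB, 1 * MB, has_equi_keys=False))     # BroadcastNestedLoopJoin
```

Because every branch depends on estimated sizes, stale or missing table statistics can push Spark onto a different branch than the real data sizes warrant.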
2. Broadcast Hash Join
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder \
    .appName("JoinStrategies") \
    .config("spark.sql.autoBroadcastJoinThreshold", "10m") \
    .getOrCreate()
# Broadcast Join: Small table sent to all executors
# No shuffle required for either side
large_df = spark.read.parquet("s3://data/transactions/") # 10 GB
small_df = spark.read.parquet("s3://data/currencies/") # 5 MB
# Explicit broadcast (forces broadcast even if over threshold)
result = large_df.join(F.broadcast(small_df), "currency_code")
# Verify in plan:
result.explain(True)
# == Physical Plan ==
# *(1) Project [id, amount, currency_code, currency_name]
# +- *(1) BroadcastHashJoin [currency_code], [currency_code], Inner, BuildRight, false
#    :- *(1) FileScan parquet [id, amount, currency_code]
#    +- BroadcastExchange HashedRelationBroadcastMode
#       +- *(1) FileScan parquet [currency_code, currency_name]
# Mathematical analysis:
# Let L = large table size, S = small table size, E = number of executors
# Broadcast Join Cost:
# 1. Collect small table to driver: O(S)
# 2. Build the hash relation on the driver: O(S)
# 3. Serialize and broadcast one copy per executor: O(S × E) network bytes
# 4. Probe the hash table for each large-side row (no shuffle): O(L)
# Total: O(S × E + L)
#
# Memory requirement per executor (each executor holds one deserialized copy):
# Hash table size ≈ S × 1.5 (hash-table overhead), where S is the in-memory
# (uncompressed) size, which can be much larger than the compressed Parquet size
# Example: if the currency table is ~5 MB in memory, 5 MB × 1.5 ≈ 7.5 MB per executor
#
# When to use:
# S < autoBroadcastJoinThreshold (default 10 MB)
# AND S × 1.5 < available executor memory (the driver must also hold the table)
# AND S × E network traffic is acceptable (many executors multiply the cost)
# Broadcast Join Configuration:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50m") # 50 MB
spark.conf.set("spark.sql.broadcastTimeout", "600")  # seconds to wait for a broadcast (default 300)
# For large clusters with many executors:
# Broadcast network traffic = S × E (one serialized copy per executor)
# For S = 50 MB, E = 1000: ~50 GB total network traffic
# May be slower than a sort-merge join!
# Practical broadcast size bound:
# Each executor holds ONE copy, so per-executor memory does not depend on E:
# S_max ≈ min(threshold, executor_memory × 0.3)
# For 8 GB executors: S_max = min(50 MB, 2.4 GB) = 50 MB
# E limits broadcasting through network traffic (S × E) and driver load,
# not through per-executor memory
# Alternative: Use adaptive broadcast
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
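# AQE can convert a sort-merge join to a broadcast hash join at runtime when the
# actual (post-shuffle) size of one side falls below the broadcast threshold;
# the local shuffle reader then reads the already-written shuffle files locally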
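The mechanics and the S × E arithmetic above can be checked with a plain-Python sketch (no Spark involved): the small side becomes one hash table, every partition of the large side probes it locally without a shuffle, and a rough cost helper reproduces the network-traffic numbers. `broadcast_hash_join`, `broadcast_cost`, the 1.5× overhead, and the 0.3 headroom fraction are illustrative assumptions, not Spark internals.

```python
from collections import defaultdict


def build_hash_table(rows, key):
    """Build side: map each join key to the list of rows carrying it."""
    table = defaultdict(list)
    for row in rows:
        table[row[key]].append(row)
    return dict(table)


def broadcast_hash_join(large_partitions, small_rows, key):
    """Inner join: one hash table is built, then every large-side partition probes it."""
    table = build_hash_table(small_rows, key)  # built once (Spark builds it on the driver)
    out = []
    for partition in large_partitions:          # each partition probes locally: no shuffle
        for row in partition:
            for match in table.get(row[key], []):
                out.append({**row, **match})
    return out


def broadcast_cost(small_bytes, executors, executor_mem_bytes, overhead=1.5, headroom=0.3):
    """Total bytes shipped, and whether one copy fits an executor's memory budget."""
    network_bytes = small_bytes * executors                         # one copy per executor
    fits = small_bytes * overhead <= executor_mem_bytes * headroom  # independent of E
    return network_bytes, fits


transactions = [
    [{"id": 1, "currency_code": "USD", "amount": 10.0},
     {"id": 2, "currency_code": "EUR", "amount": 7.5}],
    [{"id": 3, "currency_code": "USD", "amount": 3.0},
     {"id": 4, "currency_code": "JPY", "amount": 500.0}],
]
currencies = [
    {"currency_code": "USD", "currency_name": "US Dollar"},
    {"currency_code": "EUR", "currency_name": "Euro"},
]
print(len(broadcast_hash_join(transactions, currencies, "currency_code")))  # 3 (JPY has no match)

MB, GB = 1024**2, 1024**3
print(broadcast_cost(50 * MB, 1000, 8 * GB))  # (52428800000, True): ~50 GB shipped, memory fine
```

The two outputs separate the two constraints: per-executor memory is comfortable, while total traffic grows linearly with the executor count.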
3. Sort-Merge Join
# Sort-Merge Join: Default for large-large joins
# Both sides are sorted by join key, then merged
# Process:
# 1. Both sides shuffled by join key (same partitioner)
# 2. Each partition sorted by key
# 3. Merge step: linear scan through both sorted sequences
result = large_df1.join(large_df2, "user_id") # both large tables
result.explain(True)
# == Physical Plan ==
# *(3) SortMergeJoin [user_id], [user_id], Inner
# :- Sort [user_id ASC NULLS FIRST], false, 0
# :  +- Exchange hashpartitioning(user_id, 200)
# :     +- *(1) FileScan parquet [user_id, ...]
# +- Sort [user_id ASC NULLS FIRST], false, 0
#    +- Exchange hashpartitioning(user_id, 200)
#       +- *(2) FileScan parquet [user_id, ...]
# Mathematical analysis:
# Let N = left table rows, M = right table rows
# P = number of partitions
#
# Step 1: Shuffle
# Left: O(N) write + O(N) read = O(N)
# Right: O(M) write + O(M) read = O(M)
# Total shuffle: O(N + M)
#
# Step 2: Sort per partition
# Left: O(N/P × log(N/P)) per partition
# Right: O(M/P × log(M/P)) per partition
# Total sort: O(N log(N/P) + M log(M/P))
#
# Step 3: Merge
# O(N/P + M/P) per partition
# Total merge: O(N + M)
#
# Total cost: O(N + M + N log(N/P) + M log(M/P))
# = O(N log N + M log M) approximately
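The three steps above (shuffle with the same partitioner, per-partition sort, linear merge) can be sketched in plain Python, including the duplicate-key case where a run of equal keys on each side produces their cross product. `hash_partition`, `merge_sorted`, and `sort_merge_join` are hypothetical helpers; Spark's real implementation works on binary rows and can spill sorted runs to disk.

```python
def hash_partition(rows, key, num_partitions):
    """Shuffle step: row goes to partition hash(key) % P (same partitioner on both sides)."""
    parts = [[] for _ in range(num_partitions)]
    for row in rows:
        parts[hash(row[key]) % num_partitions].append(row)
    return parts


def merge_sorted(left, right, key):
    """Merge step for one partition; both inputs must already be sorted by key."""
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        lk, rk = left[i][key], right[j][key]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Find the run of equal keys on each side and emit their cross product.
            i_end = i
            while i_end < len(left) and left[i_end][key] == lk:
                i_end += 1
            j_end = j
            while j_end < len(right) and right[j_end][key] == lk:
                j_end += 1
            for l_row in left[i:i_end]:
                for r_row in right[j:j_end]:
                    out.append({**l_row, **r_row})
            i, j = i_end, j_end
    return out


def sort_merge_join(left_rows, right_rows, key, num_partitions=4):
    left_parts = hash_partition(left_rows, key, num_partitions)
    right_parts = hash_partition(right_rows, key, num_partitions)
    out = []
    for lp, rp in zip(left_parts, right_parts):
        lp = sorted(lp, key=lambda r: r[key])  # per-partition sort: O(n log n)
        rp = sorted(rp, key=lambda r: r[key])
        out.extend(merge_sorted(lp, rp, key))  # linear merge: O(n + m)
    return out


users = [{"user_id": u, "name": f"user{u}"} for u in range(6)]
events = [{"user_id": e % 4, "event": e} for e in range(8)]
result = sort_merge_join(events, users, "user_id")
print(len(result))  # 8: every event's user_id (0-3) exists in users
```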
# Runtime Bloom filter for shuffle joins such as sort-merge
# (Spark 3.3+, spark.sql.optimizer.runtime.bloomFilter.enabled):
# 1. Aggregate the join keys of a side that has a selective filter into a Bloom filter
# 2. Apply the filter to the other side's rows before they are shuffled
# 3. Rows whose keys are definitely absent never reach the shuffle or sort
# Can substantially reduce shuffle volume for selective joins
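The pruning idea can be modeled in plain Python: build a compact filter from the join keys of the selective side, then drop rows on the other side that definitely cannot match before they would be shuffled. The `BloomFilter` class, its sizes, and the SHA-256-based hashing are illustrative choices, not Spark's implementation. A Bloom filter may keep a few non-matching rows (false positives) but never drops a matching one.

```python
import hashlib


class BloomFilter:
    """Minimal Bloom filter: k bit positions per key derived from SHA-256."""

    def __init__(self, num_bits=1024, num_hashes=3):
        self.num_bits = num_bits
        self.num_hashes = num_hashes
        self.bits = bytearray(num_bits)  # one byte per bit, for simplicity

    def _positions(self, key):
        for i in range(self.num_hashes):
            digest = hashlib.sha256(f"{i}:{key}".encode()).digest()
            yield int.from_bytes(digest[:8], "big") % self.num_bits

    def add(self, key):
        for pos in self._positions(key):
            self.bits[pos] = 1

    def might_contain(self, key):
        return all(self.bits[pos] for pos in self._positions(key))


# Creation side: a selectively filtered dimension (50 "active" users).
active_user_ids = range(0, 1000, 20)
bf = BloomFilter(num_bits=4096, num_hashes=3)
for uid in active_user_ids:
    bf.add(uid)

# Application side: 10,000 events over 1,000 users, pruned before the "shuffle".
events = [{"user_id": i % 1000} for i in range(10_000)]
pruned = [e for e in events if bf.might_contain(e["user_id"])]
print(len(events), len(pruned))
```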
4. Shuffle Hash Join
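# Shuffle Hash Join: for one side too large to broadcast but much smaller than the other
# Both sides are shuffled by key; each task hashes its partition of the smaller
# (build) side and probes it with the larger side (no sort, no broadcast)
# Without hints, Spark chooses Shuffle Hash Join when:
# 1. Neither side is small enough to broadcast
# 2. spark.sql.join.preferSortMergeJoin is false (the default is true)
# 3. The build side fits per-partition hash maps:
#    size < autoBroadcastJoinThreshold × spark.sql.shuffle.partitions
# 4. The build side is at least 3× smaller than the other side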
# Configuration:
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
spark.conf.set("spark.sql.shuffle.partitions", "200")
# Force Shuffle Hash Join:
result = large_df.join(small_df.hint("shuffle_hash"), "user_id")  # DataFrame hint
# Or via SQL:
result = spark.sql("""
SELECT /*+ SHUFFLE_HASH(small_table) */ *
FROM large_table
JOIN small_table ON large_table.id = small_table.id
""")
# Mathematical analysis:
# Let N = larger side, M = smaller side
# P = partitions
#
# Step 1: Shuffle both sides
# O(N + M)
#
# Step 2: Build hash table for smaller side per partition
# Build: O(M/P × hash_cost)
# Memory: O(M/P × row_size × overhead) (full build-side rows, not just keys)
#
# Step 3: Probe hash table for larger side
# Probe: O(N/P × hash_cost)
#
# Total: O(N + M)
# Memory per task: O(M/P × row_size × 1.5)
#
# Comparison with Sort-Merge:
# Sort-Merge: O(N log N + M log M)
# Shuffle Hash: O(N + M)
# Shuffle Hash wins when every build-side partition fits in memory; sort-merge
# can spill sorted runs to disk, so it degrades more gracefully on large partitions
# Shuffle Hash Join Memory Constraint (M_bytes = build side's in-memory size):
# M_bytes / P × 1.5 < memory available to one task
# For M_bytes = 1 GB, P = 200: M_bytes/P ≈ 5 MB per partition
# With 1.5× hash-table overhead: ≈ 7.5 MB per task
# Fits comfortably, as long as keys are evenly spread (a hot key breaks this)
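A plain-Python sketch (no Spark) of the build/probe structure and the per-task memory estimate, with the build side measured in bytes. The hash table holds whole build-side rows, and the even-spread assumption in `per_task_build_bytes` breaks under skew, since one hot key lands all its rows in a single task's table. Function names, the 1.5 overhead factor, and the decimal units are illustrative assumptions.

```python
from collections import defaultdict


def shuffle_hash_join(big_rows, small_rows, key, num_partitions=4):
    """Inner join: shuffle both sides by key, then build a hash table per partition
    from the smaller side and probe it with the larger side."""
    big_parts = [[] for _ in range(num_partitions)]
    small_parts = [[] for _ in range(num_partitions)]
    for row in big_rows:
        big_parts[hash(row[key]) % num_partitions].append(row)
    for row in small_rows:
        small_parts[hash(row[key]) % num_partitions].append(row)

    out, largest_build = [], 0
    for big_part, small_part in zip(big_parts, small_parts):
        table = defaultdict(list)  # build: whole rows of this partition held in memory
        for row in small_part:
            table[row[key]].append(row)
        largest_build = max(largest_build, len(small_part))
        for row in big_part:       # probe: one lookup per row, no sorting
            for match in table.get(row[key], []):
                out.append({**row, **match})
    return out, largest_build


def per_task_build_bytes(build_side_bytes, partitions, overhead=1.5):
    """Hash-table memory for one task, assuming keys spread evenly across partitions."""
    return build_side_bytes / partitions * overhead


orders = [{"user_id": i % 10, "order": i} for i in range(100)]
users = [{"user_id": u, "tier": "gold" if u < 3 else "basic"} for u in range(10)]
joined, largest_build = shuffle_hash_join(orders, users, "user_id")
print(len(joined), largest_build)  # 100 3

MB, GB = 10**6, 10**9  # decimal units
print(per_task_build_bytes(1 * GB, 200) / MB)  # 7.5
```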
5. Broadcast Nested Loop Join
# Broadcast Nested Loop Join: For non-equi joins (>, <, !=, BETWEEN)
# No hash or sort index applies to the condition: every pair must be compared
# Example: range join
result = large_df.alias("l").join(
    F.broadcast(small_df.alias("s")),
    F.col("l.event_time").between(F.col("s.start_time"), F.col("s.end_time"))
)
# Mathematical analysis:
# Let L = large table rows, S = small table rows (broadcast)
# E = number of executors
#
# Cost:
# Broadcast: O(S × E)
# Nested loop: O(L/E × S) per executor (every pair is evaluated)
# Total comparisons: O(L × S)
#
# Bloom filters do not help here: they test key equality, while a range
# condition must still be evaluated pair by pair. The main lever is shrinking L
#
# Example: L = 10M rows, S = 1000 rows
# Without filtering: 10M × 1000 = 10B comparisons
# After an early filter keeps 10% of L: 1M × 1000 = 1B comparisons
# When Spark uses it:
# 1. Non-equi joins (BETWEEN, >, <, !=) where one side is under the broadcast threshold
# 2. As a last resort for non-equi outer joins when neither side is small
#    (non-equi inner joins fall back to CartesianProduct instead)
# Performance optimization:
# 1. Broadcast the small side to avoid shuffling the large side
# 2. Filter early to reduce L
# 3. Repartition the large side so comparisons spread across enough tasks
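The L × S cost and the payoff of filtering early can be checked with a small pure-Python model (no Spark). The event/window schema and the min/max pre-filter are illustrative assumptions; the pre-filter is safe because an event outside the overall [min start, max end] range cannot satisfy any BETWEEN window.

```python
def nested_loop_range_join(events, windows):
    """Non-equi join: test every (event, window) pair; return matches and comparison count."""
    matches, comparisons = [], 0
    for ev in events:
        for w in windows:  # nothing prunes the inner loop: every pair is evaluated
            comparisons += 1
            if w["start"] <= ev["t"] <= w["end"]:
                matches.append((ev["id"], w["name"]))
    return matches, comparisons


events = [{"id": i, "t": i} for i in range(100)]
windows = [{"name": "a", "start": 10, "end": 19},
           {"name": "b", "start": 40, "end": 44}]

matches, cmp_all = nested_loop_range_join(events, windows)

# Filter early: an event outside [min start, max end] cannot match any window.
lo = min(w["start"] for w in windows)
hi = max(w["end"] for w in windows)
kept = [e for e in events if lo <= e["t"] <= hi]
matches_kept, cmp_kept = nested_loop_range_join(kept, windows)

print(len(matches), cmp_all, cmp_kept)  # 15 200 70
```

The result is identical with and without the pre-filter; only the number of comparisons drops, in proportion to how much of L is removed.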
6. Skew Join Handling
# Skew Detection and Handling:
import numpy as np

def analyze_join_skew(df, key_col, sample_fraction=0.1):
    """Analyze key distribution for join skew detection.

    Counts come from a sample, so they are roughly sample_fraction × the true counts.
    """
    sample = df.sample(sample_fraction)
    # Collects one row per distinct key to the driver
    key_counts = sample.groupBy(key_col).count().collect()
    counts = [row["count"] for row in key_counts]
    if not counts:
        return {"skewed": False}

    mean_count = np.mean(counts)
    median_count = np.median(counts)
    max_count = np.max(counts)
    p95_count = np.percentile(counts, 95)
    skew_ratio = max_count / median_count if median_count > 0 else float('inf')

    return {
        "skewed": skew_ratio > 5,
        "skew_ratio": skew_ratio,
        "max_count": max_count,
        "median_count": median_count,
        "mean_count": mean_count,
        "p95_count": p95_count
    }
# AQE Skew Join (automatic):
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
# Manual Skew Handling (inner joins only):
def skew_aware_join(left_df, right_df, join_key, salt_range=10, sample_fraction=0.1):
    """Inner join with salting for hot keys on the left side."""
    # Detect skew on a sample
    key_stats = analyze_join_skew(left_df, join_key, sample_fraction)
    if not key_stats["skewed"]:
        return left_df.join(right_df, join_key)

    # Find skewed keys on the full data; scale the sampled median back up
    threshold = float(key_stats["median_count"]) / sample_fraction * 5
    key_counts = left_df.groupBy(join_key).count()
    skewed_keys = key_counts.filter(F.col("count") > threshold) \
        .select(join_key).collect()
    skewed_key_list = [row[join_key] for row in skewed_keys]

    # Split data
    normal_left = left_df.filter(~F.col(join_key).isin(skewed_key_list))
    skewed_left = left_df.filter(F.col(join_key).isin(skewed_key_list))

    # Add a random salt in [0, salt_range) to skewed rows
    skewed_left_salted = skewed_left.withColumn(
        "salt", (F.rand() * salt_range).cast("int")
    )

    # Replicate matching right-side rows once per salt value
    right_salted = right_df.filter(
        F.col(join_key).isin(skewed_key_list)
    ).withColumn(
        "salt", F.explode(F.sequence(F.lit(0), F.lit(salt_range - 1)))
    )

    # Normal join (no skew)
    normal_result = normal_left.join(right_df, join_key)

    # Salted join: joining on (key, salt) spreads each hot key over salt_range tasks
    skewed_result = skewed_left_salted.join(right_salted, [join_key, "salt"]) \
        .drop("salt")

    # Union results (both sides now have the same columns)
    return normal_result.unionByName(skewed_result)
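Both skew mechanisms can be modeled in plain Python (no Spark). `skewed_partitions` mirrors the rule described for `skewedPartitionFactor` and `skewedPartitionThresholdInBytes` (a partition must exceed both the factor × median and the byte threshold). `partition_loads` shows why salting helps, using a toy partitioner `(hash(key) + salt) % P` instead of Spark's Murmur3 hash of the composite key, and a round-robin salt instead of `F.rand()` so the output is deterministic. Both helpers are hypothetical; the price of salting, not modeled here, is replicating the matching right-side rows `salt_range` times.

```python
from statistics import median


def skewed_partitions(partition_sizes, factor=5, threshold_bytes=256 * 1024**2):
    """AQE-style check: skewed if size > factor * median AND size > threshold."""
    med = median(partition_sizes)
    return [i for i, size in enumerate(partition_sizes)
            if size > factor * med and size > threshold_bytes]


def partition_loads(keys, num_partitions, hot_keys=frozenset(), salt_range=1):
    """Rows per partition under a toy partitioner (hash(key) + salt) % P.

    Rows with a hot key get a round-robin salt in [0, salt_range); other rows get 0.
    """
    loads = [0] * num_partitions
    for i, k in enumerate(keys):
        salt = i % salt_range if k in hot_keys else 0
        loads[(hash(k) + salt) % num_partitions] += 1
    return loads


MB = 1024**2
print(skewed_partitions([100 * MB] * 7 + [2000 * MB]))  # [7]

keys = [0] * 800 + list(range(1, 201))  # key 0 is hot
print(partition_loads(keys, 8))                              # [825, 25, 25, 25, 25, 25, 25, 25]
print(partition_loads(keys, 8, hot_keys={0}, salt_range=8))  # [125, 125, 125, 125, 125, 125, 125, 125]
```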
7. Join Hint Usage
# Spark provides hints to influence join strategy selection:
# 1. Broadcast Hint
result = large_df.join(small_df.hint("broadcast"), "id")
# Forces BroadcastHashJoin
# 2. Shuffle Hash Hint
result = large_df.join(small_df.hint("shuffle_hash"), "id")
# Forces ShuffleHashJoin
# 3. Sort Merge Hint
result = large_df.join(small_df.hint("merge"), "id")  # aliases: "shuffle_merge", "mergejoin"
# Forces SortMergeJoin
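# 4. Shuffle-and-Replicate Nested Loop Hint
result = large_df.join(small_df.hint("shuffle_replicate_nl"), "id")
# Forces a Cartesian product join (inner joins only)
# To repartition before a join, use large_df.repartition(200, "id") instead of a join hint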
# SQL hints:
result = spark.sql("""
SELECT /*+ BROADCAST(small_table) */ *
FROM large_table
JOIN small_table ON large_table.id = small_table.id
""")
# AQE re-plans joins with runtime statistics, but explicit join hints still take precedence
spark.conf.set("spark.sql.adaptive.enabled", "true")
# AQE may convert SortMergeJoin → BroadcastHashJoin when one side turns out to be small
# AQE does not rescue a hinted broadcast that is too large; that still fails or OOMs
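A small model of how hint names map to strategies and which hint wins when both sides of a join carry one. The alias sets and the BROADCAST > MERGE > SHUFFLE_HASH > SHUFFLE_REPLICATE_NL priority follow the Spark 3.x SQL hints documentation; `normalize_hint` and `resolve_hints` are hypothetical helpers, not Spark APIs, and they ignore join-type restrictions (for example, a broadcast hint cannot be honored on the preserved side of an outer join).

```python
# Alias sets and precedence for Spark 3.x join strategy hints (case-insensitive).
HINT_ALIASES = {
    "BROADCAST": {"BROADCAST", "BROADCASTJOIN", "MAPJOIN"},
    "SHUFFLE_MERGE": {"SHUFFLE_MERGE", "MERGE", "MERGEJOIN"},
    "SHUFFLE_HASH": {"SHUFFLE_HASH"},
    "SHUFFLE_REPLICATE_NL": {"SHUFFLE_REPLICATE_NL"},
}
PRECEDENCE = ["BROADCAST", "SHUFFLE_MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"]


def normalize_hint(name):
    """Map a hint name to its strategy, or None if unrecognized."""
    upper = name.upper()
    for strategy, aliases in HINT_ALIASES.items():
        if upper in aliases:
            return strategy
    return None  # Spark logs a warning and ignores unrecognized hints


def resolve_hints(left_hint=None, right_hint=None):
    """Pick the winning strategy when either side of a join may carry a hint."""
    found = {normalize_hint(h) for h in (left_hint, right_hint) if h}
    found.discard(None)
    for strategy in PRECEDENCE:
        if strategy in found:
            return strategy
    return None  # no usable hint: fall back to size-based selection


print(normalize_hint("merge"))                     # SHUFFLE_MERGE
print(normalize_hint("shuffle_sort_merge_join"))   # None
print(resolve_hints("shuffle_hash", "broadcast"))  # BROADCAST
```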
⚠️ Common Pitfall
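Using the broadcast() hint on a table that's too large can cause out-of-memory errors on the driver (which collects the table and builds the hash relation) and on every executor (each holds a full copy). The hint bypasses the autoBroadcastJoinThreshold check, so always verify the table's in-memory size before broadcasting.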
💡 Interview Tip
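When discussing join strategies, always mention the memory/shuffle trade-off: a broadcast hash join spends memory (a full hash table of the small side on every executor) to avoid shuffling and sorting the large side, while a sort-merge join spends shuffle and sort time but needs no build-side hash table and can spill to disk.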
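The trade-off can be put in rough numbers with a back-of-the-envelope model (pure Python, no Spark). It treats every shuffled byte as crossing the network, which overstates sort-merge traffic, and reuses an illustrative 1.5× hash-table overhead; `join_tradeoff` is a hypothetical helper. With the 10 GB / 50 MB tables from section 2, broadcasting moves far less data on 10 executors but several times more on 1,000.

```python
def join_tradeoff(large_bytes, small_bytes, executors, overhead=1.5):
    """Rough bytes moved and hash-table memory: broadcast hash join vs sort-merge join."""
    broadcast = {
        "bytes_moved": small_bytes * executors,      # one copy shipped to each executor
        "hash_table_per_executor": small_bytes * overhead,
    }
    sort_merge = {
        "bytes_moved": large_bytes + small_bytes,    # both sides shuffled (upper bound)
        "hash_table_per_executor": 0,                # sorts instead; sort buffers can spill
    }
    return broadcast, sort_merge


MB, GB = 1024**2, 1024**3
for executors in (10, 1000):
    bhj, smj = join_tradeoff(10 * GB, 50 * MB, executors)
    print(executors, bhj["bytes_moved"] / GB, smj["bytes_moved"] / GB)
# 10 0.48828125 10.048828125
# 1000 48.828125 10.048828125
```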
Summary
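N = rows on the larger side, M = rows on the smaller (build or broadcast) side, P = shuffle partitions, E = executors.

| Strategy | Time Complexity | Shuffle Required | Memory Usage | Best For |
|---|---|---|---|---|
| Broadcast Hash | O(N + M × E) | No | O(M) per executor | Small-large equi-joins |
| Sort-Merge | O(N log N + M log M) | Yes (both sides) | O((N + M)/P) per task, spillable | Large-large equi-joins |
| Shuffle Hash | O(N + M) | Yes (both sides) | O(M/P) per task (build side) | Equi-joins where one side is ≥3× smaller but too big to broadcast |
| Broadcast Nested Loop | O(N × M) | No | O(M) per executor | Non-equi joins with a small side |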
The key to optimal join performance is understanding data sizes, key distributions, and letting Catalyst make informed decisions with accurate statistics.