PySpark Joins Optimization
Definition: Join (Relational Algebra)
A join combines two datasets based on a common key. In Spark, a join is a wide transformation that requires a shuffle unless one side is broadcast or both sides are already partitioned compatibly (for example, bucketed on the join key). Catalyst selects the join strategy (broadcast hash, sort-merge, or shuffle hash) from size statistics, configuration, and any join hints.
Definition: Broadcast Join
A broadcast join sends the smaller dataset to every executor via broadcast variables, eliminating the shuffle of the larger side. It is effective when the smaller dataset fits comfortably in driver and executor memory; Spark broadcasts automatically below spark.sql.autoBroadcastJoinThreshold (10MB by default).
Definition: Sort-Merge Join
A sort-merge join sorts both datasets by the join key, then merges them in a single pass. It is the default strategy for large-large joins; the sort dominates the cost, which grows as O(n log n) in the data size.
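Join Selectivity
$$\sigma = \frac{|A \bowtie B|}{|A| \times |B|}$$
Here,
- $\sigma$ = Join selectivity (fraction of row pairs that match)
- $|A \bowtie B|$ = Number of matching row pairs after join
- $|A|$ = Number of rows in table A
- $|B|$ = Number of rows in table B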
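As a small worked example of the selectivity defined above, the sketch below counts matching row pairs for two toy key lists in plain Python. The function name `join_selectivity` and the sample keys are illustrative assumptions, not part of Spark.

```python
from collections import Counter

def join_selectivity(keys_a, keys_b):
    """Return (matching_pairs, selectivity) for an equi-join on the given keys."""
    counts_b = Counter(keys_b)
    # Each row of A pairs with every row of B that carries the same key.
    matching_pairs = sum(counts_b[k] for k in keys_a)
    total_pairs = len(keys_a) * len(keys_b)
    selectivity = matching_pairs / total_pairs if total_pairs else 0.0
    return matching_pairs, selectivity

keys_a = [1, 1, 2, 3]      # 4 rows in table A
keys_b = [1, 2, 2, 4, 5]   # 5 rows in table B
pairs, sel = join_selectivity(keys_a, keys_b)
print(pairs, sel)  # 4 0.2
```

Four of the twenty possible row pairs match, so the selectivity is 0.2.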
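Broadcast Join Threshold
$$T_{\text{eff}} = \min\left(M_{\text{exec}} \times f,\; T_{\text{config}}\right)$$
Here,
- $T_{\text{eff}}$ = Effective broadcast threshold in bytes
- $M_{\text{exec}}$ = Executor memory available for broadcast
- $f$ = Safety factor (default 0.25, i.e. use 25% of executor memory)
- $T_{\text{config}}$ = Configured spark.sql.autoBroadcastJoinThreshold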
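The following sketch turns the quantities listed above into arithmetic. It assumes the effective threshold is the smaller of the memory budget (executor memory times the safety factor) and the configured threshold. That combination is an assumption for illustration; Spark itself only compares its size estimate against spark.sql.autoBroadcastJoinThreshold. The helper name is hypothetical.

```python
MB = 1024 * 1024

def effective_broadcast_threshold(executor_memory_bytes,
                                  configured_threshold_bytes,
                                  safety_factor=0.25):
    """Assumed rule: never broadcast more than a safe fraction of executor memory."""
    memory_budget = int(executor_memory_bytes * safety_factor)
    return min(memory_budget, configured_threshold_bytes)

# A 4 GB executor: the configured 10 MB threshold is the binding limit.
print(effective_broadcast_threshold(4096 * MB, 10 * MB) // MB)  # 10
# A tiny 32 MB executor: 25% of memory (8 MB) is the binding limit.
print(effective_broadcast_threshold(32 * MB, 10 * MB) // MB)    # 8
```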
Sort-Merge Join Cost
$$C_{\text{SMJ}} = O(|A| \log |A| + |B| \log |B|) + O(|A| + |B|)$$
Here,
- $C_{\text{SMJ}}$ = Total sort-merge join cost
- $|A|$ = Size of dataset A (rows or bytes)
- $|B|$ = Size of dataset B (rows or bytes)
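One common way to model this cost, assumed here rather than taken from Spark's optimizer, is a sort term of n log n per side plus a linear merge term. The sketch below evaluates that model to show how the cost reacts when both inputs double; the function name and unit-less "cost" are illustrative.

```python
import math

def sort_merge_cost(size_a, size_b):
    """Abstract cost: sort both sides (n log2 n each), then merge in one pass."""
    def sort_cost(n):
        return n * math.log2(n) if n > 1 else 0.0
    merge_cost = size_a + size_b
    return sort_cost(size_a) + sort_cost(size_b) + merge_cost

base = sort_merge_cost(1024, 1024)
doubled = sort_merge_cost(2048, 2048)
print(base, doubled)  # 22528.0 49152.0
```

Doubling both inputs raises the modeled cost by slightly more than 2x, which is the n log n growth of the sort phase.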
Catalyst selects join strategy based on estimated table sizes: if the smaller side is below autoBroadcastJoinThreshold, it uses BroadcastHashJoin; otherwise, it defaults to SortMergeJoin (most general) or ShuffleHashJoin (for medium-sized data).
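The selection rule described above can be sketched as a small decision function. This is a deliberately simplified model, not Catalyst's actual planner: real planning also weighs hints, join type, and per-partition size checks. The function name and the `prefer_sort_merge` flag (mirroring spark.sql.join.preferSortMergeJoin) are illustrative assumptions.

```python
MB = 1024 * 1024
GB = 1024 * MB

def choose_join_strategy(size_a_bytes, size_b_bytes,
                         broadcast_threshold_bytes=10 * MB,
                         prefer_sort_merge=True):
    """Simplified strategy choice based only on estimated sizes."""
    smaller = min(size_a_bytes, size_b_bytes)
    # A negative threshold disables automatic broadcast, as -1 does in Spark.
    if broadcast_threshold_bytes >= 0 and smaller <= broadcast_threshold_bytes:
        return "BroadcastHashJoin"
    if not prefer_sort_merge:
        return "ShuffleHashJoin"
    return "SortMergeJoin"

print(choose_join_strategy(1 * GB, 5 * MB))                                # BroadcastHashJoin
print(choose_join_strategy(1 * GB, 1 * GB))                                # SortMergeJoin
print(choose_join_strategy(1 * GB, 5 * MB, broadcast_threshold_bytes=-1))  # SortMergeJoin
```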
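For skewed joins, use a skew hint where your platform provides one (for example, the SKEW hint on Databricks) or AQE's skew join handling. With spark.sql.adaptive.enabled and spark.sql.adaptive.skewJoin.enabled set, AQE detects skewed partitions at runtime and splits them into sub-partitions to balance the workload.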
Theorem: Broadcast Join Optimization
Theorem: If one side of a join is broadcast, the shuffle cost is reduced from O(P × N × W) to O(P × N_{small} × W), where P is the number of partitions on the large side, N is the row count of the large table, W is the row width, and N_{small} is the row count of the small table. The ratio of the two costs, and hence the reduction in data moved, is N / N_{small}.
- Broadcast joins eliminate shuffle when one side fits in memory
- Default broadcast threshold is 10MB; increase to 50-100MB for large clusters
- Sort-merge join is the most general strategy; Spark shuffles and sorts both sides by the join key
- Bucket joins eliminate shuffle for repeated joins on the same key
- AQE handles skewed joins automatically at runtime
- Filter before join to reduce data size before shuffle
- Use `left_semi` instead of `inner` when you only need left-table columns
[Figure: Join Strategy Selection Flowchart]
[Figure: Broadcast vs Sort-Merge Join]
[Figure: Architecture Diagram]
Join Types Overview
Join types fall into four categories:
Equi Joins (join on an equality condition, =):
- Inner Join: Returns only matching rows from both tables
- Left Outer: All left rows + matching right rows (NULL if no match)
- Right Outer: All right rows + matching left rows (NULL if no match)
- Full Outer: All rows from both tables (NULL if no match on either side)
Semi/Anti Joins (filter-based joins):
- Left Semi: Returns left rows that have a match in right (no right columns included)
- Left Anti: Returns left rows that have NO match in right
Cross Join (Cartesian product):
- Cartesian Product: Every row from A paired with every row from B (use with caution)
Optimization Strategies (how Spark executes joins):
- Broadcast Join: Small table sent to all executors (no shuffle)
- Sort-Merge Join: Both sides shuffled, sorted, then merged (default for large tables)
- Bucket Join: Pre-bucketed tables avoid shuffle on repeated joins
Detailed Explanation
1. Join Types in PySpark
PySpark supports several join types, each with different semantics:
| Join Type | Description | Use Case |
|---|---|---|
| Inner Join | Returns only rows that have matching keys in both DataFrames | Default, most common |
| Left Outer Join | All left rows + matching right rows (NULL if no match) | Keep all left records |
| Right Outer Join | All right rows + matching left rows (NULL if no match) | Keep all right records |
| Full Outer Join | All rows from both tables (NULL if no match on either side) | Complete data merge |
| Left Semi Join | Left rows that have a match in right (no right columns) | EXISTS check |
| Left Anti Join | Left rows that have NO match in right | NOT EXISTS check |
| Cross Join | Every row from A paired with every row from B | Cartesian product (avoid!) |
2. Broadcast Hash Join
Broadcast hash join is the most efficient strategy when one table fits in memory.
Configuration:
# Default threshold: 10MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
# Manual broadcast hint
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")
Advantages:
- No shuffle required
- Parallel execution on all executors
- Low memory overhead for small table
- Very fast for small-large joins
Limitations:
- Small table must fit in executor memory
- Network overhead for broadcasting
- Not suitable for large-large joins
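To make the mechanics concrete, here is a plain-Python sketch of the idea behind a broadcast hash join: build one hash table from the small side, then let every partition of the large side probe it locally. The list-of-lists "partitions" and the function name are stand-ins for illustration; this is not Spark's implementation.

```python
from collections import defaultdict

def broadcast_hash_join(large_partitions, small_rows):
    """Inner-join each large-side partition against a shared hash table."""
    # Build phase: one hash table from the small side, keyed by join key.
    # In Spark this table is what gets shipped to every executor.
    hash_table = defaultdict(list)
    for key, value in small_rows:
        hash_table[key].append(value)

    # Probe phase: each partition stays where it is and is joined independently,
    # so no large-side row has to move across the network.
    joined_partitions = []
    for partition in large_partitions:
        joined = [(key, left, right)
                  for key, left in partition
                  for right in hash_table.get(key, [])]
        joined_partitions.append(joined)
    return joined_partitions

large_partitions = [[(1, "a"), (2, "b")], [(3, "c"), (1, "d")]]
small_rows = [(1, "X"), (2, "Y")]
print(broadcast_hash_join(large_partitions, small_rows))
# [[(1, 'a', 'X'), (2, 'b', 'Y')], [(1, 'd', 'X')]]
```

The output keeps the large side's partitioning intact, which is why the strategy avoids a shuffle.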
3. Sort-Merge Join
Sort-merge join is the default strategy for joining large tables.
Three Phases:
- Shuffle: Both tables are partitioned by key
- Sort: Each partition is sorted by key
- Merge: Sorted partitions are merged
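The sort and merge phases can be sketched for a single partition in plain Python. The shuffle phase is omitted, since it only decides which rows land in which partition. This is an illustrative inner join over (key, value) tuples, not Spark's code.

```python
def sort_merge_join(left_rows, right_rows):
    """Inner join of (key, value) rows using sort, then a single merge pass."""
    # Sort phase: order both sides by join key.
    left = sorted(left_rows, key=lambda r: r[0])
    right = sorted(right_rows, key=lambda r: r[0])

    # Merge phase: advance two cursors; neither side is rescanned from the start.
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        lk, rk = left[i][0], right[j][0]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Find the run of right rows sharing this key.
            j_end = j
            while j_end < len(right) and right[j_end][0] == lk:
                j_end += 1
            # Pair every left row with that key against the whole run.
            while i < len(left) and left[i][0] == lk:
                for _, right_value in right[j:j_end]:
                    result.append((lk, left[i][1], right_value))
                i += 1
            j = j_end
    return result

left_rows = [(3, "c"), (1, "a"), (2, "b"), (2, "b2")]
right_rows = [(2, "Y"), (4, "Z"), (1, "X"), (2, "Y2")]
print(sort_merge_join(left_rows, right_rows))
# [(1, 'a', 'X'), (2, 'b', 'Y'), (2, 'b', 'Y2'), (2, 'b2', 'Y'), (2, 'b2', 'Y2')]
```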
Optimization Techniques:
- Partition pruning: Skip partitions that don't match
- Bucketing: Pre-partition data by join key
- Sort merge join with Bloom filter: Skip non-matching keys
4. Shuffle Hash Join
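Shuffle hash join shuffles both sides by the join key, then builds an in-memory hash table from the smaller side within each partition. Spark prefers sort-merge join by default (spark.sql.join.preferSortMergeJoin), so this strategy is used mainly when one table is moderately sized or when it is requested with a SHUFFLE_HASH hint.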
When to Use:
- Medium-sized tables (10MB - 1GB)
- High selectivity joins
- When broadcast join is not possible
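A plain-Python sketch of the idea: both sides are routed to partitions by hashing the join key, and within each partition a hash table built from one side is probed by the other. Integer keys keep Python's `hash` deterministic here; the function name and the choice of build side as a parameter are illustrative assumptions.

```python
def shuffle_hash_join(probe_rows, build_rows, num_partitions=4):
    """Inner join: co-partition both sides by key, then hash-join per partition."""
    def shuffle(rows):
        # Rows with the same key always land in the same partition index.
        parts = [[] for _ in range(num_partitions)]
        for key, value in rows:
            parts[hash(key) % num_partitions].append((key, value))
        return parts

    result = []
    for probe_part, build_part in zip(shuffle(probe_rows), shuffle(build_rows)):
        # Build a hash table from the (smaller) build side of this partition only.
        table = {}
        for key, value in build_part:
            table.setdefault(key, []).append(value)
        # Probe it with the other side; no sorting is needed.
        for key, value in probe_part:
            for match in table.get(key, []):
                result.append((key, value, match))
    return result

orders = [(1, "o1"), (5, "o2"), (5, "o3"), (7, "o4")]   # probe side
users = [(1, "alice"), (2, "bob"), (5, "eve")]          # build side
print(sorted(shuffle_hash_join(orders, users)))
# [(1, 'o1', 'alice'), (5, 'o2', 'eve'), (5, 'o3', 'eve')]
```

Compared with sort-merge, this trades the sort for a hash table that must fit in memory per partition, which is why it suits moderately sized build sides.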
5. Cartesian Product Join
Cartesian product joins produce all combinations of rows from both tables.
- Complexity: O(N × M)
- Use cases: Only when business logic requires all combinations
- Warning: Extremely expensive; avoid unless absolutely necessary
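The N × M growth is easy to see with the standard library's `itertools.product`, used here as a stand-in for a cross join. The sample lists and the helper `cross_join_rows` are illustrative.

```python
from itertools import product

sizes = ["S", "M", "L"]
colors = ["red", "blue"]

# Every size paired with every color: 3 x 2 = 6 rows.
combos = list(product(sizes, colors))
print(len(combos))  # 6

def cross_join_rows(n_rows_a, n_rows_b):
    """Output row count of a Cartesian product."""
    return n_rows_a * n_rows_b

# Two tables of one million rows each produce a trillion output rows.
print(cross_join_rows(1_000_000, 1_000_000))  # 1000000000000
```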
6. Join Optimization Strategies
Broadcast Hints:
result = large_df.join(broadcast(small_df), "key")
Bucketing:
users_df.write.bucketBy(100, "user_id").sortBy("user_id").saveAsTable("users_bucketed")
orders_df.write.bucketBy(100, "user_id").sortBy("user_id").saveAsTable("orders_bucketed")
Partitioning:
df = df.repartition(100, "user_id")
7. Data Skew in Joins
Data skew occurs when some keys have significantly more data than others.
Detection:
- Monitor task duration in Spark UI
- Look for tasks with much longer duration
- Check shuffle read/write metrics
Mitigation:
- Salting: Add random prefix to skewed keys
- Broadcast join: Avoid shuffle for skewed tables
- AQE: Adaptive Query Execution (Spark 3.0+, enabled by default since 3.2) splits skewed shuffle partitions at runtime when spark.sql.adaptive.skewJoin.enabled is true
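The effect of salting can be checked without a cluster. The sketch below uses plain Python and measures the largest group of rows that share one join key, since all rows for a key must be processed by the same task. A deterministic salt (`n % NUM_SALTS`) replaces the random salt a real job would use, so that the numbers are reproducible; all names and data are illustrative.

```python
from collections import Counter

NUM_SALTS = 10

# Fact side: 1,000 rows for one hot key, 10 rows each for two normal keys.
facts = ([("hot", n) for n in range(1000)]
         + [("a", n) for n in range(10)]
         + [("b", n) for n in range(10)])
dims = [("hot", "Hot"), ("a", "A"), ("b", "B")]

def largest_group(keys):
    """Size of the biggest set of rows sharing a single join key."""
    return max(Counter(keys).values())

# Salt the skewed side: the hot key is spread over NUM_SALTS sub-keys,
# every other key keeps salt 0.
salted_facts = [((key, n % NUM_SALTS if key == "hot" else 0), n)
                for key, n in facts]

# Replicate the other side so each salted sub-key still finds its match.
salted_dims = {}
for key, name in dims:
    for salt in (range(NUM_SALTS) if key == "hot" else range(1)):
        salted_dims[(key, salt)] = name

joined = [(key, n, salted_dims[(key, salt)])
          for (key, salt), n in salted_facts
          if (key, salt) in salted_dims]

before = largest_group(key for key, _ in facts)
after = largest_group(salted_key for salted_key, _ in salted_facts)
print(before, after, len(joined))  # 1000 100 1020
```

The join still returns one output row per fact row, while the largest key group shrinks by a factor of ten.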
8. Join Order Optimization
The order of joins can significantly impact performance:
- Join smaller tables first to reduce intermediate results
- Use broadcast joins for small-large joins
- Consider join cardinality (1:1, 1:N, N:M)
Best Practice: Use `left_semi` instead of `inner` when you only need left-table columns and only care whether a match exists; it avoids duplicate columns, never multiplies left rows, and is usually faster.
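The two join types are only interchangeable when the right side has at most one match per key. The plain-Python sketch below, with illustrative toy data, shows the difference: an inner join repeats a left row once per matching right row, while a left semi join returns each matching left row once.

```python
customers = [(1, "Alice"), (2, "Bob"), (3, "Cara")]
orders = [(1, "book"), (1, "pen"), (3, "lamp")]

# Inner join: one output row per matching (customer, order) pair.
inner = [(cid, name, item)
         for cid, name in customers
         for oid, item in orders
         if cid == oid]

# Left semi join: keep a customer if at least one order exists; left columns only.
order_keys = {oid for oid, _ in orders}
left_semi = [(cid, name) for cid, name in customers if cid in order_keys]

print(inner)      # [(1, 'Alice', 'book'), (1, 'Alice', 'pen'), (3, 'Cara', 'lamp')]
print(left_semi)  # [(1, 'Alice'), (3, 'Cara')]
```

Alice appears twice in the inner join but once in the semi join, so switching join types can change row counts as well as columns.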
Key Concepts Table
| Join Type | Description | Shuffle? | Use Case |
|---|---|---|---|
| Inner Join | Only matching rows from both | Yes | Default, most common |
| Left Outer | All left + matching right | Yes | Keep all left records |
| Right Outer | All right + matching left | Yes | Keep all right records |
| Full Outer | All from both tables | Yes | Complete data merge |
| Left Semi | Left rows with match in right | Yes | EXISTS check |
| Left Anti | Left rows without match in right | Yes | NOT EXISTS check |
| Cross Join | Cartesian product | Yes | All combinations |
| Broadcast Join | Small table broadcast | No | Small + large table |
| Sort-Merge Join | Sort both, then merge | Yes | Large + large table |
| Shuffle Hash Join | Hash table in memory | Yes | Medium + medium |
Code Examples
Example 1: Basic Join Types
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
spark = SparkSession.builder.appName("JoinOptimization").getOrCreate()
# Create sample DataFrames
employees = spark.createDataFrame([
(1, "Alice", "Engineering"),
(2, "Bob", "Marketing"),
(3, "Charlie", "Engineering"),
(4, "Diana", "Sales"),
(5, "Eve", None)
], ["id", "name", "department"])
# Schema: id INT, name STRING, department STRING (nullable)
departments = spark.createDataFrame([
("Engineering", "San Francisco", 50),
("Marketing", "New York", 30),
("Sales", "Chicago", 40),
("HR", "Boston", 20)
], ["dept_name", "location", "headcount"])
# Schema: dept_name STRING, location STRING, headcount INT
# Inner Join
# Parameters: other DataFrame, join expression, join type
# join(other, on, how): how defaults to "inner"
inner = employees.join(departments, employees.department == departments.dept_name, "inner")
# Returns only rows where department matches dept_name in both tables
print("Inner Join:")
inner.show()
# Left Outer Join
# Parameter: "left" keeps all rows from the left DataFrame
left = employees.join(departments, employees.department == departments.dept_name, "left")
# Eve (department=None) will appear with null right columns
print("Left Outer Join:")
left.show()
# Right Outer Join
# Parameter: "right" keeps all rows from the right DataFrame
right = employees.join(departments, employees.department == departments.dept_name, "right")
# HR (dept_name="HR") will appear with null left columns
print("Right Outer Join:")
right.show()
# Full Outer Join
# Parameter: "full" keeps all rows from both DataFrames
full = employees.join(departments, employees.department == departments.dept_name, "full")
print("Full Outer Join:")
full.show()
# Left Semi Join
# Parameter: "left_semi" returns only left columns where a match exists
# No duplicate columns, no right-side nulls
semi = employees.join(departments, employees.department == departments.dept_name, "left_semi")
print("Left Semi Join:")
semi.show()
# Left Anti Join
# Parameter: "left_anti" returns only left columns where NO match exists
# Equivalent to: SELECT * FROM employees e WHERE NOT EXISTS
#   (SELECT 1 FROM departments d WHERE e.department = d.dept_name)
# Unlike NOT IN, this keeps Eve, whose department is NULL
anti = employees.join(departments, employees.department == departments.dept_name, "left_anti")
print("Left Anti Join:")
anti.show()
Example 2: Broadcast Join
from pyspark.sql.functions import broadcast, col
# Create large and small DataFrames
large_df = spark.range(1000000).withColumn("key", col("id") % 1000)
# 1 million rows, keys 0-999
small_df = spark.createDataFrame([
(i, f"category_{i}") for i in range(100)
], ["key", "category"])
# 100 rows: small enough to broadcast
# Method 1: Broadcast hint
# Parameter: broadcast(df) marks the DataFrame for broadcast
result_hint = large_df.join(broadcast(small_df), "key")
# Spark will broadcast small_df to all executors
# Method 2: Configure auto broadcast threshold
# Parameter: spark.sql.autoBroadcastJoinThreshold
# Default: 10485760 (10MB)
# Set to -1 to disable auto-broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
result_auto = large_df.join(small_df, "key")
# Catalyst selects BroadcastHashJoin when its size estimate for small_df is at most 10MB;
# the estimate is most reliable for file- or table-backed sources with statistics
# Check the execution plan of the hinted join
result_hint.explain()
# The plan should contain a BroadcastHashJoin fed by a BroadcastExchange on the
# small_df side. Abbreviated shape (exact text varies by Spark version):
# == Physical Plan ==
# ... BroadcastHashJoin [key], [key], Inner, BuildRight
# :- ... (large_df side, no shuffle Exchange)
# +- BroadcastExchange HashedRelationBroadcastMode(...)
#    +- ... (small_df side)
Example 3: Sort-Merge Join with Bucketing
# Create bucketed tables
users = spark.createDataFrame([
(i, f"user_{i}", i % 10) for i in range(100000)
], ["user_id", "name", "dept_id"])
orders = spark.createDataFrame([
(i, i % 100000, i * 10.0) for i in range(1000000)
], ["order_id", "user_id", "amount"])
# Write bucketed tables
# Parameters:
# bucketBy(20, "user_id"): 20 hash buckets by user_id
# sortBy("user_id"): sort within each bucket
# mode("overwrite"): replace the table if it exists, so the example can be re-run
# saveAsTable("users_bucketed"): persist as a metastore table
users.write \
    .bucketBy(20, "user_id") \
    .sortBy("user_id") \
    .mode("overwrite") \
    .saveAsTable("users_bucketed")
orders.write \
    .bucketBy(20, "user_id") \
    .sortBy("user_id") \
    .mode("overwrite") \
    .saveAsTable("orders_bucketed")
# Disable auto-broadcast so these small demo tables are not broadcast instead
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
# Read bucketed tables
users_bucketed = spark.table("users_bucketed")
orders_bucketed = spark.table("orders_bucketed")
# Join bucketed tables (no shuffle needed)
# Both tables are bucketed by the same key with the same number of buckets
result = users_bucketed.join(orders_bucketed, "user_id")
result.explain()
# Check that no shuffle occurs:
# expect a SortMergeJoin with no Exchange node on either side
Example 4: Handling Data Skew
from pyspark.sql.functions import *
# Create skewed data: 10 rows for each of user_0..user_9, 100K rows for "skewed_user"
skewed_data = spark.createDataFrame(
[(i, f"user_{i % 10}") for i in range(100)] +
[(i + 100, "skewed_user") for i in range(100000)], # Skewed key
["id", "user_id"]
)
# "skewed_user" has 100K rows vs 10 rows for other keys
user_data = spark.createDataFrame([
(f"user_{i}", f"Name {i}") for i in range(10)
] + [("skewed_user", "Skewed User")],
["user_id", "name"]
)
# Method 1: Broadcast if possible
# If user_data is small enough, broadcast eliminates skew entirely
result = skewed_data.join(broadcast(user_data), "user_id")
# Method 2: Salting (add random prefix to skewed keys)
# Steps:
# 1. Identify skewed keys
# 2. Add random salt to skewed keys in both DataFrames
# 3. Join on salted keys
# Add salt to skewed data
salted = skewed_data.withColumn(
"salt", # new column name
when(
col("user_id") == "skewed_user", # condition for skewed key
(rand() * 10).cast("int") # random salt 0-9
).otherwise(0) # non-skewed keys get salt=0
).withColumn(
"salted_key", # combined key
concat(col("user_id"), lit("_"), col("salt"))
# "skewed_user_3", "user_1_0", etc.
)
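# Expand user data with salts
# Every user_data row is replicated for salts 0-9, then only the copies that can
# match are kept: all ten for the skewed key, salt 0 for the other keys
user_with_salt = user_data.crossJoin(
    spark.range(10).withColumnRenamed("id", "salt")
).filter(
    (col("user_id") == "skewed_user") | (col("salt") == 0)
).withColumn(
    "salted_key",
    concat(col("user_id"), lit("_"), col("salt"))
).select("salted_key", "name")  # drop user_id and salt to avoid duplicate column names
# Join on salted keys
result = salted.join(user_with_salt, "salted_key")
result.explain()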
Example 5: Join with Complex Conditions
from pyspark.sql.functions import *
# Create DataFrames with multiple join keys
orders = spark.createDataFrame([
(1, "2024-01-15", "electronics", 100.0),
(2, "2024-01-15", "clothing", 50.0),
(3, "2024-02-20", "electronics", 200.0),
], ["order_id", "order_date", "category", "amount"])
targets = spark.createDataFrame([
("2024-01", "electronics", 500.0),
("2024-01", "clothing", 300.0),
("2024-02", "electronics", 800.0),
], ["month", "category", "target"])
# Join on a list of column names
# Parameter: list of column names, which generates equality conditions
# and keeps a single copy of each join column in the output
result_multi = orders.join(
targets,
["category"], # join on category only
"inner"
)
# Join with a compound condition
# Parameter: boolean Column expression, combined with & (and) or | (or)
result_complex = orders.join(
targets,
(orders.category == targets.category) &
(orders.order_date.substr(1, 7) == targets.month), # "2024-01-15" -> "2024-01"
"inner"
)
# Left semi join: keep orders whose amount is below a target for the same category
result_filtered = orders.join(
targets,
(orders.category == targets.category) &
(orders.amount < targets.target),
"left_semi"
)
# Returns only order columns, for orders with amount < target in a matching category
Performance Metrics
| Join Type | 1GB + 10MB | 1GB + 1GB | 10GB + 10GB | Shuffle Size |
|---|---|---|---|---|
| Broadcast Join | 2.5s | N/A | N/A | 0 MB |
| Sort-Merge Join | 8.5s | 12.0s | 45.0s | 2x input |
| Shuffle Hash Join | 6.0s | 9.0s | N/A | 2x input |
| Broadcast (SQL) | 2.0s | N/A | N/A | 0 MB |
| Bucket Join | 4.0s | 6.0s | 25.0s | 0 MB |
| Left Semi Join | 5.0s | 8.0s | 30.0s | 1x input |
| Left Anti Join | 4.5s | 7.0s | 28.0s | 1x input |
| Cross Join | 120.0s | 1200.0s | N/A | N/A |
| Metric | 1GB Data | 10GB Data | 100GB Data | Notes |
|---|---|---|---|---|
| Broadcast Threshold | 10MB | 50MB | 100MB | Increase for large clusters |
| Optimal Partitions | 8-16 | 80-160 | 800-1600 | 128MB per partition |
| Skew Ratio | < 2x | < 2x | < 2x | > 2x indicates skew |
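The "Optimal Partitions" row follows from dividing the data size by the 128MB target per partition. A minimal sketch of that arithmetic, with a hypothetical helper name, reproduces the lower bound of each range in the table; the result could be used as a starting point for spark.sql.shuffle.partitions.

```python
import math

MB = 1024 ** 2
GB = 1024 ** 3

def suggested_shuffle_partitions(data_size_bytes, target_partition_bytes=128 * MB):
    """Number of partitions needed so that each holds about the target size."""
    return max(1, math.ceil(data_size_bytes / target_partition_bytes))

for size_gb in (1, 10, 100):
    print(size_gb, suggested_shuffle_partitions(size_gb * GB))
# 1 8
# 10 80
# 100 800
```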
Best Practices
1. Use Broadcast Joins for Small Tables
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")
# Or configure threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
2. Bucket Tables for Repeated Joins
# Write bucketed tables
# bucketBy(numBuckets, *cols): hash partition into fixed buckets
# sortBy(*cols): sort within buckets for merge efficiency
df1.write.bucketBy(100, "key").sortBy("key").saveAsTable("t1_bucketed")
df2.write.bucketBy(100, "key").sortBy("key").saveAsTable("t2_bucketed")
# Join bucketed tables without shuffle
result = spark.table("t1_bucketed").join(spark.table("t2_bucketed"), "key")
3. Handle Data Skew
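# Broadcast if possible
result = skewed_df.join(broadcast(small_df), "key")
# Or use salting for large skewed joins: add a random salt to the skewed side,
# replicate the other side once per salt value, then join on (key, salt)
from pyspark.sql.functions import rand
salted_df = skewed_df.withColumn("salt", (rand() * 10).cast("int"))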
4. Choose Correct Join Type
# Use left_semi instead of inner when you only need left table columns
result = df1.join(df2, "key", "left_semi") # Faster, no duplicate columns
# Use left_anti for NOT EXISTS
result = df1.join(df2, "key", "left_anti")
5. Filter Before Join
# Filter early to reduce data size
result = df1.filter(col("age") > 30).join(df2, "key")
6. Monitor Join Performance
# Check execution plan
result.explain(True)
# Look for:
# - BroadcastHashJoin (good for small tables)
# - SortMergeJoin (default for large tables)
# - Exchange (shuffle operations)