Data Skew Solutions: Salting, AQE Skew Join, Bucketing
Difficulty: Expert | Companies: Meta, Uber, Netflix, Airbnb, Stripe
ℹ️ Interview Context
Data skew is one of the most common causes of slow Spark jobs. Interviewers expect you to understand how to detect skew, several mitigation strategies, and when each one applies.
Question
What is data skew and how does it impact Spark performance? Explain multiple strategies to handle data skew: salting, AQE skew join, isolate-and-process, and bucketing. How do you detect skew programmatically? Provide mathematical analysis of skew impact.
Detailed Answer
1. Data Skew: Mathematical Analysis
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
spark = SparkSession.builder \
.appName("DataSkewSolutions") \
.config("spark.sql.shuffle.partitions", "200") \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.adaptive.skewJoin.enabled", "true") \
.getOrCreate()
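# Data skew definition:
# One key (or a few keys) has significantly more records than the others.
# All rows of a key hash to the same shuffle partition,
# which causes uneven partition sizes → straggler tasks
# Mathematical impact:
# Let:
# N = total rows
# K = number of distinct keys
# P = number of shuffle partitions
# S = partition-level skew factor (largest_partition_rows / average_partition_rows)
# Note: the key-level ratio (max_key_count / avg_key_count) is what detection
# measures; it is usually far larger than S when most keys are small
#
# Without skew (uniform distribution, S ≈ 1):
# Partition size = N / P
# Task time = T_compute(N / P)
# Stage time = T_compute(N / P) (tasks run in parallel and finish together)
#
# With skew (S = 10, meaning the largest partition is 10x the average):
# Skewed partition size = (N / P) × S
# Task time = T_compute((N / P) × S)
# Stage time = T_compute((N / P) × S) (the stage waits for the bottleneck task)
#
# Speedup from fixing skew: up to S (10x in this example), assuming T_compute
# is roughly linear in rows and all tasks fit in one wave of cores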
# Example: join with skew
left_df = spark.range(10000000).withColumn(
"key", F.when(F.col("id") < 1000000, F.lit("hot_key"))
.otherwise(F.concat(F.lit("key_"), F.col("id") % 9000000))
)
right_df = spark.range(1000000).withColumn(
"key", F.concat(F.lit("key_"), F.col("id") % 1000000)
)
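# hot_key has 1M records; the other ~9M keys have 1 record each
# Key-level skew: 1M / (10M rows / ~9M keys) ≈ 900,000x the average key count
# Partition-level skew with P = 200: (45K + 1M) / 50K ≈ 21x the average partition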
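To make the straggler effect concrete, here is a small pure-Python sketch (no Spark required) that hashes keys into partitions and compares the largest partition with the average. It assumes an MD5-based hash as a stand-in for Spark's partitioner, and `partition_sizes` and `straggler_ratio` are hypothetical helper names, not Spark APIs.

```python
import hashlib
from collections import Counter


def partition_sizes(keys, num_partitions):
    """Count rows per partition under a deterministic hash partitioner."""
    sizes = Counter()
    for key in keys:
        # MD5 stands in for Spark's hash function; only "hash mod P" matters here
        digest = hashlib.md5(key.encode()).digest()
        sizes[int.from_bytes(digest[:8], "big") % num_partitions] += 1
    return [sizes.get(p, 0) for p in range(num_partitions)]


def straggler_ratio(sizes):
    """Largest partition divided by the mean partition size."""
    return max(sizes) / (sum(sizes) / len(sizes))


# Scaled-down version of the example: 10% of rows share one hot key
n_rows, n_partitions = 100_000, 20
uniform_keys = [f"key_{i}" for i in range(n_rows)]
skewed_keys = ["hot_key" if i < n_rows // 10 else f"key_{i}" for i in range(n_rows)]

uniform = partition_sizes(uniform_keys, n_partitions)
skewed = partition_sizes(skewed_keys, n_partitions)
print(f"uniform straggler ratio: {straggler_ratio(uniform):.2f}")
print(f"skewed straggler ratio:  {straggler_ratio(skewed):.2f}")
```

With 10% of the rows on one key and 20 partitions, the hot partition holds roughly three times the average, so the stage would take about three times as long as the balanced case if task time grows linearly with row count.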
2. Skew Detection
def detect_skew(df, key_col, sample_fraction=0.1, skew_threshold=5.0):
    """Detect data skew by sampling and analyzing the key distribution."""
    # Sample for efficiency; counts shrink with sample_fraction,
    # so the ratio is only comparable between runs that use the same fraction
    sample = df.sample(sample_fraction)
    # Count per key (cached because it is scanned twice below)
    key_counts = sample.groupBy(key_col).count().cache()
    # Calculate statistics
    stats = key_counts.select(
        F.mean("count").alias("mean_count"),
        F.expr("percentile_approx(`count`, 0.5)").alias("median_count"),
        F.expr("percentile_approx(`count`, 0.95)").alias("p95_count"),
        F.max("count").alias("max_count"),
        F.count("*").alias("num_keys")
    ).collect()[0]
    skew_ratio = stats["max_count"] / stats["median_count"]
    print("Skew Analysis:")
    print(f" Number of keys: {stats['num_keys']:,}")
    print(f" Mean count: {stats['mean_count']:,.0f}")
    print(f" Median count: {stats['median_count']:,.0f}")
    print(f" P95 count: {stats['p95_count']:,.0f}")
    print(f" Max count: {stats['max_count']:,}")
    print(f" Skew ratio (max/median): {skew_ratio:.2f}")
    # Find the top keys by count (collected once, then printed)
    top_rows = key_counts.orderBy(F.desc("count")).limit(10).collect()
    print("\nTop 10 keys by count:")
    for row in top_rows:
        print(f" {row[key_col]}: {row['count']:,}")
    key_counts.unpersist()
    # Report only keys well above the median as skewed
    cutoff = skew_threshold * stats["median_count"]
    return {
        "is_skewed": skew_ratio > skew_threshold,
        "skew_ratio": skew_ratio,
        "top_keys": [row[key_col] for row in top_rows if row["count"] > cutoff]
    }
# Detect skew in left_df
skew_info = detect_skew(left_df, "key")
# Approximate output with the default 10% sample:
# Skew ratio (max/median): ~100,000 (about 1M × 0.1 sampled hot_key rows over a median of 1)
# top_keys: ["hot_key"]
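The same max/median rule can be tried without a cluster. The sketch below is a plain-Python analogue of the detection logic, using `collections.Counter` in place of `groupBy().count()`; `skew_report` is a hypothetical name, and taking every tenth row is only a deterministic stand-in for random sampling.

```python
from collections import Counter
from statistics import median


def skew_report(keys, skew_threshold=5.0):
    """Compute the max/median key-count ratio and list keys above the cutoff."""
    counts = Counter(keys)
    med = median(counts.values())
    top_count = counts.most_common(1)[0][1]
    ratio = top_count / med
    hot_keys = [k for k, c in counts.items() if c > skew_threshold * med]
    return {"skew_ratio": ratio, "is_skewed": ratio > skew_threshold, "hot_keys": hot_keys}


# 1,000 rows on one hot key plus 9,000 keys with one row each
keys = ["hot_key"] * 1000 + [f"key_{i}" for i in range(9000)]
full = skew_report(keys)

# Every tenth row: the hot key keeps 100 rows while the median count stays at 1
sampled = skew_report(keys[::10])

print(full["skew_ratio"], sampled["skew_ratio"], full["hot_keys"])
```

The ratio computed on the sample is about ten times smaller than on the full data, which is why a threshold tuned on full counts should not be reused unchanged on sampled counts.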
3. Salting Technique
# Salting: attach a random salt to skewed keys so their rows spread over several partitions
# Process: split the skewed join into a salted part (hot keys) and a normal part
def salted_join(left_df, right_df, join_key, salt_range=100):
    """Join with salting to handle skew on the left side."""
    # Detect skew
    skew_info = detect_skew(left_df, join_key)
    skewed_keys = skew_info["top_keys"]
    if not skew_info["is_skewed"] or not skewed_keys:
        return left_df.join(right_df, join_key)
    # Split data
    is_skewed = F.col(join_key).isin(skewed_keys)
    normal_left = left_df.filter(~is_skewed)
    skewed_left = left_df.filter(is_skewed)
    # Add a random salt in [0, salt_range) to each skewed left row
    skewed_left_salted = skewed_left.withColumn(
        "salt", (F.rand() * salt_range).cast("int")
    )
    # Replicate each matching right row once per salt value.
    # F.sequence builds an array column, so it is exploded; crossJoin needs a DataFrame
    right_salted = right_df.filter(is_skewed).withColumn(
        "salt", F.explode(F.sequence(F.lit(0), F.lit(salt_range - 1)))
    )
    # Normal join (no skew)
    normal_result = normal_left.join(right_df, join_key)
    # Salted join: joining on (key, salt) is equivalent to joining on a concatenated
    # salted key and keeps a single key column in the result
    skewed_result = skewed_left_salted.join(
        right_salted, [join_key, "salt"]
    ).drop("salt")
    # Union results
    return normal_result.unionByName(skewed_result)
# Apply salted join; the right-side id is renamed so the result has unique column names
result = salted_join(left_df, right_df.withColumnRenamed("id", "right_id"), "key")
# Salting mathematical analysis:
# Let H = rows carrying the hot key, S = skew factor, P = partitions, R = salt range
#
# Without salt:
# Hot partition size: ≈ H + (N - H) / P
# Other partitions: ≈ (N - H) / P
#
# With salt (R = 100):
# Hot key is spread over up to R salted keys, hence up to R partitions
# Hot-key rows per salted key: ≈ H / R
# Other partitions unchanged
# Cost: every matching right-side row is replicated R times
#
# Improvement: hot-key rows per partition shrink by about R; residual skew ≈ S / R
# For S = 1000, R = 100: residual skew ≈ 10x
# Salt range selection:
# R should be large enough to spread the hot key
# But not too large (right-side replication, shuffle volume, and memory grow with R)
#
# Rule of thumb (a heuristic, tune per workload): R = max(10, S / 10)
# For S = 1000: R = 100
# For S = 100: R = 10
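The effect of the salt range can be checked with a small pure-Python simulation. This is a sketch under stated assumptions: MD5 stands in for the partitioner's hash, the salt is drawn from a seeded `random.Random` so the run is repeatable, and `bucket`, `max_partition_rows`, and `salt_hot_rows` are hypothetical helper names.

```python
import hashlib
import random
from collections import Counter


def bucket(key, num_partitions):
    """Deterministic hash partitioner: same key, same partition."""
    digest = hashlib.md5(key.encode()).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


def max_partition_rows(keys, num_partitions):
    """Row count of the largest partition."""
    sizes = Counter(bucket(k, num_partitions) for k in keys)
    return max(sizes.values())


def salt_hot_rows(keys, hot_keys, salt_range, seed=42):
    """Append a random salt in [0, salt_range) to rows whose key is hot."""
    rng = random.Random(seed)
    return [f"{k}_{rng.randrange(salt_range)}" if k in hot_keys else k for k in keys]


n_partitions, salt_range = 50, 20
keys = ["hot_key"] * 50_000 + [f"key_{i}" for i in range(50_000)]

before = max_partition_rows(keys, n_partitions)
after = max_partition_rows(salt_hot_rows(keys, {"hot_key"}, salt_range), n_partitions)
print(f"largest partition before salting: {before:,}")
print(f"largest partition after salting:  {after:,}")
```

The largest partition shrinks substantially, but usually by less than a factor of R, because several salted keys can still hash to the same partition. That collision effect is one reason to keep R comfortably below the partition count or to accept some residual imbalance.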
4. AQE Skew Join (Automatic)
# Adaptive Query Execution (AQE) handles skew automatically
# No manual salting required!
# Configuration:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
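# How AQE skew join works:
# 1. Execute the shuffle (map) stages
# 2. Collect shuffle partition size statistics
# 3. Mark a partition as skewed when size > factor × median AND size > threshold
# 4. Split each skewed partition into smaller sub-partitions
# 5. Have the other side's matching partition read once per sub-partition
# 6. Execute the join with balanced partitions
# AQE skew detection in more detail:
# 1. Calculate median partition size: median = P50(partition_sizes)
# 2. For each partition: skewed if size > skewedPartitionFactor × median
# and size > skewedPartitionThresholdInBytes
# 3. Split a skewed partition along groups of map outputs, aiming for roughly
# spark.sql.adaptive.advisoryPartitionSizeInBytes per sub-partition
# 4. No keys are rewritten: unlike manual salting, AQE adds no salt or prefix
# 5. The matching partition on the other side is duplicated for each sub-partition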
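The detection rule is easy to model outside Spark. The following is a simplified sketch, not Spark's implementation: `find_skewed_partitions` and `planned_splits` are hypothetical names, the defaults mirror the two settings configured above, and the 64 MB target is an assumed advisory partition size. Real AQE splits along map-output boundaries, so actual sub-partition counts can differ.

```python
from math import ceil
from statistics import median

MB = 1024 ** 2


def find_skewed_partitions(sizes_bytes, factor=5.0, threshold_bytes=256 * MB):
    """Return indexes of partitions that exceed both the relative and absolute limits."""
    med = median(sizes_bytes)
    return [
        i for i, size in enumerate(sizes_bytes)
        if size > factor * med and size > threshold_bytes
    ]


def planned_splits(size_bytes, target_bytes=64 * MB):
    """Rough number of sub-partitions needed to reach the target size."""
    return ceil(size_bytes / target_bytes)


# Shuffle partition sizes after the map stage (illustrative values)
sizes = [64 * MB, 70 * MB, 58 * MB, 900 * MB, 66 * MB, 300 * MB]

skewed = find_skewed_partitions(sizes)
for idx in skewed:
    print(f"partition {idx}: {sizes[idx] // MB} MB -> {planned_splits(sizes[idx])} sub-partitions")
```

Here the median is 68 MB, so the relative limit is 340 MB. The 900 MB partition exceeds both limits and is flagged, while the 300 MB partition passes the absolute threshold but not the relative one and is left alone.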
# AQE vs manual salting:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 180" width="100%" style={{ maxWidth: 700 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="skew-hdr" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<filter id="skew-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
</filter>
</defs>
<rect x="10" y="10" width="780" height="160" rx="14" fill="#fff" filter="url(#skew-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
<rect x="10" y="10" width="780" height="30" rx="14" fill="url(#skew-hdr)"/>
<rect x="10" y="24" width="780" height="16" fill="url(#skew-hdr)"/>
<text x="160" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Aspect</text>
<text x="400" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">AQE Skew Join</text>
<text x="640" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Manual Salting</text>
<text x="160" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Detection</text>
<text x="400" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Automatic</text>
<text x="640" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Manual analysis</text>
<line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Implementation</text>
<text x="400" y="80" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Zero code change</text>
<text x="640" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Custom code</text>
<line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Salt range</text>
<text x="400" y="104" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Adaptive</text>
<text x="640" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Fixed</text>
<line x1="30" y1="114" x2="770" y2="114" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Performance</text>
<text x="400" y="128" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">2-10x improvement</text>
<text x="640" y="128" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">5-100x improvement</text>
<line x1="30" y1="138" x2="770" y2="138" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="152" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Compatibility</text>
<text x="400" y="152" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Spark 3.x only</text>
<text x="640" y="152" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">All versions</text>
<line x1="30" y1="162" x2="770" y2="162" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="168" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Edge cases</text>
<text x="400" y="168" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">May miss some</text>
<text x="640" y="168" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Handles all</text>
</svg>
</div>
# Example: AQE skew join
result = left_df.join(right_df, "key") # AQE splits skewed partitions at runtime
# Verify in Spark UI:
# SQL / DataFrame tab → open the query → look for a join node marked skew=true
# (e.g. SortMergeJoin(skew=true)) in the final plan
5. Isolate and Process
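# Isolate skewed keys and process separately
# Best for: known hot keys (e.g., NULL, default values)
def isolate_skewed_keys(df, key_col, skewed_keys):
    """Split data into normal and skewed parts."""
    # isin() follows SQL semantics: a NULL in the list never matches and makes
    # ~isin() evaluate to NULL for every other row, so NULL keys use isNull()
    non_null_keys = [k for k in skewed_keys if k is not None]
    is_hot = F.col(key_col).isin(non_null_keys)
    if None in skewed_keys:
        is_hot = is_hot | F.col(key_col).isNull()
    is_hot = F.coalesce(is_hot, F.lit(False))
    return df.filter(~is_hot), df.filter(is_hot)
# Example: process NULL and placeholder keys separately
skewed_keys = [None, "unknown", "default"]
# Rename the right-side id so the joined result has unique column names
right_renamed = right_df.withColumnRenamed("id", "right_id")
normal_left, skewed_left = isolate_skewed_keys(left_df, "key", skewed_keys)
normal_right, skewed_right = isolate_skewed_keys(right_renamed, "key", skewed_keys)
# Process normally (no skew)
normal_result = normal_left.join(normal_right, "key")
# Process skewed with special handling
# Note: NULL keys never match in an inner equi-join, so they add no rows here
# Option 1: Broadcast small skewed side
if skewed_right.count() < 1000000:  # Small enough to broadcast
    skewed_result = skewed_left.join(
        F.broadcast(skewed_right), "key"
    )
else:
    # Option 2: Repartition with more partitions
    # This separates different hot keys, but all rows of one key still hash to
    # a single partition; for one dominant key use salting (section 3)
    skewed_result = skewed_left.repartition(1000, "key") \
        .join(skewed_right.repartition(1000, "key"), "key")
# Union results
final_result = normal_result.unionByName(skewed_result)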
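A list of hot keys that contains `None` needs care, because an `isin`-style test follows SQL three-valued logic. The sketch below models that logic in plain Python to show what a naive split would keep; `sql_in`, `sql_not`, `sql_filter`, and `is_hot` are hypothetical helpers written for illustration, not PySpark functions.

```python
def sql_in(value, candidates):
    """Model SQL `value IN (candidates)`: returns True, False, or None (SQL NULL)."""
    if value is None:
        return None
    if any(c is not None and c == value for c in candidates):
        return True
    # No match: the result is NULL if the list contains a NULL, otherwise False
    return None if any(c is None for c in candidates) else False


def sql_not(result):
    """SQL NOT: NULL stays NULL."""
    return None if result is None else not result


def sql_filter(rows, predicate):
    """A filter keeps a row only when the predicate is exactly True."""
    return [r for r in rows if predicate(r) is True]


def is_hot(key, hot_keys):
    """NULL-safe test: NULL is matched explicitly, other keys through IN."""
    non_null = [k for k in hot_keys if k is not None]
    return (key is None and None in hot_keys) or sql_in(key, non_null) is True


rows = ["a", "unknown", None, "b"]
hot = [None, "unknown"]

naive_hot = sql_filter(rows, lambda k: sql_in(k, hot))
naive_normal = sql_filter(rows, lambda k: sql_not(sql_in(k, hot)))
safe_hot = [r for r in rows if is_hot(r, hot)]
safe_normal = [r for r in rows if not is_hot(r, hot)]

print(naive_hot, naive_normal)
print(safe_hot, safe_normal)
```

In the naive version the NULL key is never selected as hot, and the negated filter returns nothing at all, so the "normal" side of the split would be empty. Testing NULL separately restores both halves.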
6. Bucketing for Repeated Joins
# Bucketing: Pre-partition data by key at write time
# Eliminates shuffle for repeated joins on same key
# Write bucketed tables
left_df.write \
.bucketBy(200, "key") \
.sortBy("key") \
.mode("overwrite") \
.saveAsTable("left_bucketed")
right_df.write \
.bucketBy(200, "key") \
.sortBy("key") \
.mode("overwrite") \
.saveAsTable("right_bucketed")
# Join bucketed tables (no shuffle!)
left_bucketed = spark.table("left_bucketed")
right_bucketed = spark.table("right_bucketed")
result = left_bucketed.join(right_bucketed, "key")
# Bucketing mathematical analysis:
# Let P = bucket count, N = rows
# Each bucket: N / P rows (if keys are uniformly distributed)
# Join: local join within each pair of matching buckets (no shuffle)
#
# Bucketing vs shuffle join:
# Shuffle join: O(N + M) network I/O for the shuffle
# Bucketed join: no shuffle I/O (the cost was paid once at write time)
#
# Bucketing vs salting:
# Salting: runtime salt addition, temporary
# Bucketing: write-time distribution, permanent
# Caveat: buckets are assigned by hashing the key, so a hot key still lands in a
# single bucket; bucketing removes repeated shuffles but does not rebalance that key
# Bucketing best practices:
# 1. Use same bucket count for both tables
# 2. Use same bucket key for both tables
# 3. Bucket count should be ≥ total executor cores for parallelism
# 4. Target bucket size: 128-256 MB
# Bucket count calculation:
num_executors = 50
cores_per_executor = 4
total_cores = num_executors * cores_per_executor # 200
# Use total_cores as bucket count for full parallelism
# Or use data_size / target_bucket_size
data_size_gb = 100
target_bucket_mb = 128
optimal_buckets = int(data_size_gb * 1024 / target_bucket_mb) # 800
# Use max(total_cores, optimal_buckets) for both parallelism and size
bucket_count = max(total_cores, optimal_buckets) # 800
7. Hybrid Approaches
# Combine multiple strategies for complex skew scenarios
def hybrid_skew_solution(left_df, right_df, join_key):
    """Hybrid approach: AQE + salting + isolation."""
    # Step 1: Enable AQE for automatic handling
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
    # Step 2: Isolate known hot keys
    # NULL is left out of the list: NULL keys never match in an inner equi-join,
    # and a NULL inside isin() would make ~isin() NULL for every other row
    hot_keys = ["unknown"]
    is_hot = F.coalesce(F.col(join_key).isin(hot_keys), F.lit(False))
    normal_left = left_df.filter(~is_hot)
    hot_left = left_df.filter(is_hot)
    normal_right = right_df.filter(~is_hot)
    hot_right = right_df.filter(is_hot)
    # Step 3: Process normal data (AQE handles remaining skew)
    normal_result = normal_left.join(normal_right, join_key)
    # Step 4: Process hot keys with salting
    salt_range = 100
    # Add a random salt in [0, salt_range) to each hot left row
    hot_left_salted = hot_left.withColumn(
        "salt", (F.rand() * salt_range).cast("int")
    )
    # Replicate right side once per salt value (explode the sequence array;
    # crossJoin only accepts a DataFrame, not a Column)
    hot_right_salted = hot_right.withColumn(
        "salt", F.explode(F.sequence(F.lit(0), F.lit(salt_range - 1)))
    )
    # Salted join on (key, salt) keeps a single key column in the result
    hot_result = hot_left_salted.join(
        hot_right_salted, [join_key, "salt"]
    ).drop("salt")
    # Union all results; an empty hot side adds no rows, so no count() job is needed
    return normal_result.unionByName(hot_result)
# Apply hybrid solution; the right-side id is renamed so the result has unique column names
result = hybrid_skew_solution(left_df, right_df.withColumnRenamed("id", "right_id"), "key")
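⚠️ Common Pitfall
Salting adds overhead when the data is not skewed: the matching rows on the other side are replicated once per salt value, and the join carries an extra column. Always check whether skew exists before salting, for example with the skew detection function from section 2.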
💡 Interview Tip
When discussing skew solutions, always mention that AQE is the first line of defense in Spark 3.x. Manual salting is for cases where AQE doesn't handle skew well (e.g., extreme skew, complex join patterns).
Summary
| Strategy | When to Use | Implementation | Improvement |
|---|---|---|---|
| AQE Skew Join | Default for Spark 3.x | Config only, no code change | 2-10x |
| Salting | Extreme skew, AQE insufficient | Manual code | 5-100x |
| Isolate Keys | Known hot keys | Split + special handling | 10-100x |
| Bucketing | Repeated joins on same key | Write-time distribution | Eliminates shuffle |
| Hybrid | Complex scenarios | Combine strategies | Maximum improvement |
The key to solving data skew is: detect first, then apply the simplest solution that works. Start with AQE, escalate to salting if needed, and use bucketing for repeated operations.