
Data Skew Solutions: Salting, AQE Skew Join, Bucketing


Difficulty: Expert | Companies: Meta, Uber, Netflix, Airbnb, Stripe

Interview Context

Data skew is one of the most common causes of slow Spark jobs. Interviewers expect you to understand skew detection, multiple mitigation strategies, and when to use each approach.

Question

What is data skew and how does it impact Spark performance? Explain multiple strategies to handle data skew: salting, AQE skew join, isolate-and-process, and bucketing. How do you detect skew programmatically? Provide mathematical analysis of skew impact.


Detailed Answer

1. Data Skew: Mathematical Analysis

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder \
    .appName("DataSkewSolutions") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .getOrCreate()

# Data skew definition:
# A few keys have significantly more records than the others
# This causes uneven partition sizes → straggler tasks

# Mathematical impact:
# Let:
# N = total rows
# K = number of distinct keys
# P = number of partitions
# S = skew factor (largest partition size / average partition size N / P)
#     Note: the key-level ratio max_key_count / avg_key_count is usually much
#     larger, because each partition holds many keys
#
# Without skew (uniform distribution):
#   Partition size = N / P
#   Task time = T_compute(N / P)
#   Total time ≈ T_compute(N / P) (parallel execution)
#
# With skew (S = 10, meaning the largest partition holds 10x the average):
#   Skewed partition size = S × N / P
#   Task time = T_compute(S × N / P)
#   Total time ≈ T_compute(S × N / P) (bottleneck task)
#
# Speedup from fixing skew: up to S (10x in this example), assuming
# T_compute grows roughly linearly with partition size

# Example: join with skew
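left_df = spark.range(10000000).withColumn(
    "key", F.when(F.col("id") < 1000000, F.lit("hot_key"))
            .otherwise(F.concat(F.lit("key_"), (F.col("id") % 9000000).cast("string")))
)

# One right-side row carries hot_key so the skewed key actually has a match
right_df = spark.range(1000000).withColumn(
    "key", F.when(F.col("id") == 0, F.lit("hot_key"))
            .otherwise(F.concat(F.lit("key_"), F.col("id").cast("string")))
)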

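# hot_key has 1M records; the other 9M distinct keys have 1 record each
# Key-level skew factor: 1M / (10M / ~9M keys ≈ 1.1) ≈ 900,000x
# Partition-level skew factor with 200 partitions:
#   (1M + 9M / 200) / (10M / 200) = 1,045,000 / 50,000 ≈ 21x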
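The skew of the example data can be checked without a cluster. The sketch below is plain Python; `skew_factors` is a hypothetical helper that assumes the non-hot rows spread evenly across the 200 shuffle partitions and that `hot_key` lands in exactly one of them. The partition-level number is the one that bounds the straggler task, so it is the S that belongs in the speedup formula.

```python
def skew_factors(total_rows, hot_rows, other_keys, partitions):
    """Key-level and partition-level skew for a single hot key (simplified model)."""
    # Key level: hot key count divided by the average rows per distinct key
    avg_rows_per_key = total_rows / (other_keys + 1)
    key_skew = hot_rows / avg_rows_per_key

    # Partition level: the hot key hashes to one partition, and the remaining
    # rows are assumed to spread evenly over all partitions
    avg_partition = total_rows / partitions
    hot_partition = hot_rows + (total_rows - hot_rows) / partitions
    partition_skew = hot_partition / avg_partition
    return key_skew, partition_skew


key_skew, partition_skew = skew_factors(
    total_rows=10_000_000, hot_rows=1_000_000, other_keys=9_000_000, partitions=200
)
print(f"key-level skew:       {key_skew:,.0f}x")       # ~900,000x
print(f"partition-level skew: {partition_skew:.1f}x")  # 20.9x
```

The straggler task therefore handles roughly 21 times the average partition, not 900,000 times; the key-level ratio mostly tells you which key to target.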

2. Skew Detection

def detect_skew(df, key_col, sample_fraction=0.1, skew_threshold=5.0):
    """Detect data skew by sampling and analyzing the key distribution."""
    # Sample for efficiency (counts are scaled down by sample_fraction)
    sample = df.sample(fraction=sample_fraction)

    # Count rows per key
    key_counts = sample.groupBy(key_col).count()

    # Calculate statistics (backticks quote the `count` column name)
    stats = key_counts.select(
        F.mean("count").alias("mean_count"),
        F.expr("percentile_approx(`count`, 0.5)").alias("median_count"),
        F.expr("percentile_approx(`count`, 0.95)").alias("p95_count"),
        F.max("count").alias("max_count"),
        F.count("*").alias("num_keys")
    ).collect()[0]

    skew_ratio = stats["max_count"] / stats["median_count"]

    print("Skew Analysis:")
    print(f"  Number of keys: {stats['num_keys']:,}")
    print(f"  Mean count: {stats['mean_count']:,.0f}")
    print(f"  Median count: {stats['median_count']:,.0f}")
    print(f"  P95 count: {stats['p95_count']:,.0f}")
    print(f"  Max count: {stats['max_count']:,}")
    print(f"  Skew ratio (max/median): {skew_ratio:.2f}")

    # Collect the 10 largest keys once, then keep only those far above the median
    top_rows = key_counts.orderBy(F.desc("count")).limit(10).collect()
    print("\nTop 10 keys by count:")
    for row in top_rows:
        print(f"  {row[key_col]}: {row['count']:,}")

    return {
        "is_skewed": skew_ratio > skew_threshold,
        "skew_ratio": skew_ratio,
        "top_keys": [row[key_col] for row in top_rows
                     if row["count"] > skew_threshold * stats["median_count"]]
    }

# Detect skew in left_df
skew_info = detect_skew(left_df, "key")
# Output (approximate, since counts come from a 10% sample):
# Skew ratio (max/median): ~100000 (hot_key appears ~100K times; median is 1)
# top_keys: ["hot_key"]
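The same detection logic can be tried on a laptop. The sketch below is a hypothetical plain-Python analogue of `detect_skew` (the name `detect_skew_local` is ours): `random`, `Counter`, and `statistics.median` stand in for Spark's `sample`, `groupBy().count()`, and `percentile_approx`. It also shows why only keys well above the median should be reported: on data shaped like `left_df`, every key except `hot_key` has a count of 1.

```python
import random
from collections import Counter
from statistics import median


def detect_skew_local(keys, sample_fraction=0.1, skew_threshold=5.0, seed=42):
    """Plain-Python analogue of detect_skew (hypothetical helper, no Spark needed)."""
    # Bernoulli sample, like DataFrame.sample(fraction=...)
    rng = random.Random(seed)
    sample = [k for k in keys if rng.random() < sample_fraction]

    # Count rows per key, like groupBy(key).count()
    counts = Counter(sample)
    med = median(counts.values())
    ratio = max(counts.values()) / med

    # Report only keys well above the median, not simply the 10 largest
    top_keys = [k for k, c in counts.most_common(10) if c > skew_threshold * med]
    return {"is_skewed": ratio > skew_threshold, "skew_ratio": ratio, "top_keys": top_keys}


# Scaled-down version of left_df: one hot key plus many keys seen once
keys = ["hot_key"] * 10_000 + [f"key_{i}" for i in range(90_000)]
info = detect_skew_local(keys)
print(info["is_skewed"], info["top_keys"])  # True ['hot_key']
```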

3. Salting Technique

# Salting: Add random prefix to distribute skewed keys
# Process: Split skewed join into salted + normal parts

def salted_join(left_df, right_df, join_key, salt_range=100):
    """Inner join that salts skewed keys to spread them across partitions."""

    # Detect skew
    skew_info = detect_skew(left_df, join_key)

    if not skew_info["is_skewed"] or not skew_info["top_keys"]:
        return left_df.join(right_df, join_key)

    # Get skewed keys
    skewed_keys = skew_info["top_keys"]
    is_skewed_key = F.col(join_key).isin(skewed_keys)

    # Split both sides into normal and skewed parts
    normal_left = left_df.filter(~is_skewed_key)
    skewed_left = left_df.filter(is_skewed_key)
    normal_right = right_df.filter(~is_skewed_key)
    skewed_right = right_df.filter(is_skewed_key)

    # Left side: give each skewed row a random salt in [0, salt_range)
    skewed_left_salted = skewed_left.withColumn(
        "salt", (F.rand() * salt_range).cast("int")
    )

    # Right side: replicate each skewed row once per salt value
    # (crossJoin needs a DataFrame, so explode a sequence column instead)
    skewed_right_salted = skewed_right.withColumn(
        "salt", F.explode(F.sequence(F.lit(0), F.lit(salt_range - 1)))
    )

    # Normal join (no skew)
    normal_result = normal_left.join(normal_right, join_key)

    # Salted join on (key, salt) spreads each hot key over up to salt_range tasks;
    # joining on both columns avoids duplicate key/salt columns in the output
    skewed_result = skewed_left_salted.join(
        skewed_right_salted, [join_key, "salt"]
    ).drop("salt")

    # Union results (both joins produce the same column order)
    return normal_result.union(skewed_result)

# Apply salted join
result = salted_join(left_df, right_df, "key")
# Salting mathematical analysis:
# Let S = partition-level skew factor, P = partitions, R = salt range (R ≤ P)
#
# Without salt:
#   Skewed partition size: S × N / P
#   Other partitions: ≈ N / P
#
# With salt (R = 100):
#   Skewed key distributed across up to R partitions
#   Each salted partition: ≈ S × N / (P × R)
#   Other partitions unchanged
#   Cost: each matching right-side row is replicated R times
#
# Improvement: the skewed partition shrinks by a factor of R,
# leaving it S / R times the average partition
# For S = 1000, R = 100: 100x smaller, still 10x the average

# Salt range selection:
# R should be large enough to spread the skew
# but not too large (more replicated right-side rows, more shuffle and memory)
#
# Rule of thumb: R = max(10, S / 10)
# For S = 1000: R = 100
# For S = 100: R = 10
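The effect of R on the fullest partition can be simulated in plain Python. This is a toy model under stated assumptions: `md5` stands in for Spark's Murmur3 hash partitioner, salts are assigned round-robin instead of with `rand()` so runs are reproducible, and the sizes are scaled down. `max_partition_rows` and `replicated_right_rows` are hypothetical helpers.

```python
import hashlib
from collections import Counter


def partition_of(key, num_partitions):
    # md5 is a stand-in for Spark's Murmur3 hash partitioner (assumption)
    return int(hashlib.md5(key.encode()).hexdigest(), 16) % num_partitions


def max_partition_rows(hot_rows, other_keys, num_partitions, salt_range=1):
    """Rows in the fullest partition when the hot key is split over salt_range salts."""
    sizes = Counter()
    # Partition of each salted variant of the hot key; salt_range=1 means no salting
    hot_partition = [partition_of(f"hot_key_{s}", num_partitions) for s in range(salt_range)]
    for i in range(hot_rows):
        # Round-robin salt instead of rand() so the result is reproducible
        sizes[hot_partition[i % salt_range]] += 1
    for k in range(other_keys):
        sizes[partition_of(f"key_{k}", num_partitions)] += 1
    return max(sizes.values())


def replicated_right_rows(right_hot_rows, salt_range):
    """Right-side rows for the hot key after exploding one copy per salt value."""
    return right_hot_rows * salt_range


before = max_partition_rows(100_000, 100_000, num_partitions=200)
after = max_partition_rows(100_000, 100_000, num_partitions=200, salt_range=100)
print(f"fullest partition without salt: {before:,} rows")
print(f"fullest partition with R=100:   {after:,} rows")
print(f"right-side hot rows to shuffle: {replicated_right_rows(1, 100)}")  # 100
```

Hash collisions mean some partitions receive two or three salted variants, so the reduction is somewhat less than the ideal factor of R.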

4. AQE Skew Join (Automatic)

# Adaptive Query Execution (AQE) can split skewed partitions of shuffle
# (e.g. sort-merge) joins automatically, often with no manual salting

# Configuration:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")

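# How AQE skew join works:
# 1. Execute the shuffle map stages for both join inputs
# 2. Collect per-partition size statistics from the map outputs
# 3. Detect skewed partitions (size > factor × median AND size > threshold bytes)
# 4. Split each skewed partition into sub-partitions (by ranges of map outputs)
# 5. Duplicate the matching partition on the other side for each sub-partition
# 6. Execute the join with balanced tasks

# AQE skew detection algorithm (per join side):
# 1. Calculate the median partition size: median = P50(partition_sizes)
# 2. For each partition: if size > max(skewedPartitionFactor × median,
#    skewedPartitionThresholdInBytes) → skewed
# 3. Split each skewed partition into sub-partitions of roughly a target size
#    (the larger of advisoryPartitionSizeInBytes and the average non-skewed size)
# 4. Splits follow map-output boundaries; no keys are salted or rewritten
# 5. Each sub-partition is joined with a full copy of the other side's matching partition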
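The detection rule can be modeled in a few lines. The sketch below is a simplified plain-Python model, not Spark's implementation: `plan_skew_splits` is hypothetical, its defaults mirror `skewedPartitionFactor` (5), `skewedPartitionThresholdInBytes` (256 MB), and `spark.sql.adaptive.advisoryPartitionSizeInBytes` (64 MB by default, as we understand it), and the target split size follows our reading of Spark 3.x's `OptimizeSkewedJoin` rule. Real split counts are approximate because Spark cuts splits at map-output boundaries.

```python
import math
from statistics import median

MB = 1024 ** 2


def plan_skew_splits(sizes, factor=5.0, threshold=256 * MB, advisory=64 * MB):
    """Simplified model of AQE skew-join planning for one side of a join.

    Returns {partition_index: approximate number of splits} for skewed partitions.
    """
    med = median(sizes)
    # Skewed only if larger than BOTH factor x median and the byte threshold
    skew_cutoff = max(threshold, factor * med)
    non_skewed = [s for s in sizes if s <= skew_cutoff]
    # Target split size: advisory size or average non-skewed size, whichever is larger
    target = max(advisory, sum(non_skewed) / len(non_skewed)) if non_skewed else advisory
    return {i: math.ceil(s / target) for i, s in enumerate(sizes) if s > skew_cutoff}


sizes = [10 * MB] * 199 + [2048 * MB]  # 199 partitions of 10 MB, one of 2 GB
print(plan_skew_splits(sizes))  # {199: 32}

# 100x the median, but under the 256 MB threshold: not treated as skewed
print(plan_skew_splits([1 * MB] * 99 + [100 * MB]))  # {}
```

The second call illustrates why AQE may leave skew untouched on small datasets: a partition must clear the absolute byte threshold as well as the relative factor.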

# AQE vs manual salting:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 200" width="100%" style={{ maxWidth: 700 }} xmlns="http://www.w3.org/2000/svg">
  <defs>
    <linearGradient id="skew-hdr" x1="0" y1="0" x2="1" y2="1">
      <stop offset="0%" stopColor="#6366f1"/>
      <stop offset="100%" stopColor="#4f46e5"/>
    </linearGradient>
    <filter id="skew-shadow">
      <feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
    </filter>
  </defs>
  <rect x="10" y="10" width="780" height="180" rx="14" fill="#fff" filter="url(#skew-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
  <rect x="10" y="10" width="780" height="30" rx="14" fill="url(#skew-hdr)"/>
  <rect x="10" y="24" width="780" height="16" fill="url(#skew-hdr)"/>
  <text x="160" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Aspect</text>
  <text x="400" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">AQE Skew Join</text>
  <text x="640" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Manual Salting</text>
  <text x="160" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Detection</text>
  <text x="400" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Automatic</text>
  <text x="640" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Manual analysis</text>
  <line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="160" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Implementation</text>
  <text x="400" y="80" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Zero code change</text>
  <text x="640" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Custom code</text>
  <line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="160" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Split granularity</text>
  <text x="400" y="104" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Adaptive (runtime stats)</text>
  <text x="640" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Fixed salt range</text>
  <line x1="30" y1="114" x2="770" y2="114" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="160" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Performance</text>
  <text x="400" y="128" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">2-10x improvement</text>
  <text x="640" y="128" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">5-100x improvement</text>
  <line x1="30" y1="138" x2="770" y2="138" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="160" y="152" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Compatibility</text>
  <text x="400" y="152" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Spark 3.0+</text>
  <text x="640" y="152" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">All versions</text>
  <line x1="30" y1="162" x2="770" y2="162" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="160" y="176" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Edge cases</text>
  <text x="400" y="176" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">May miss sub-threshold skew</text>
  <text x="640" y="176" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Full manual control</text>
</svg>
</div>

# Example: AQE skew join
result = left_df.join(right_df, "key")  # AQE splits partitions that exceed both skew thresholds

# Verify in the Spark UI:
# SQL / DataFrame tab → open the query → look for SortMergeJoin(skew=true) in the final plan

5. Isolate and Process

# Isolate skewed keys and process separately
# Best for: known hot keys (e.g., NULL, default values)

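def isolate_skewed_keys(df, key_col, skewed_keys):
    """Split data into normal and skewed parts (NULL keys go to the skewed part)."""
    # isin() never matches NULL, so NULL keys need an explicit isNull() check
    non_null_keys = [k for k in skewed_keys if k is not None]
    is_skewed = F.col(key_col).isin(non_null_keys)
    if len(non_null_keys) < len(skewed_keys):
        is_skewed = is_skewed | F.col(key_col).isNull()
    normal_df = df.filter(~is_skewed)
    skewed_df = df.filter(is_skewed)
    return normal_df, skewed_df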

# Example: process NULL and placeholder keys separately
skewed_keys = [None, "unknown", "default"]

normal_left, skewed_left = isolate_skewed_keys(left_df, "key", skewed_keys)
normal_right, skewed_right = isolate_skewed_keys(right_df, "key", skewed_keys)

# Process normally (no skew)
normal_result = normal_left.join(normal_right, "key")

# Process skewed keys with special handling
# Option 1: Broadcast the skewed right side if it is small
# (row count is a rough proxy; the real limit is its size in memory)
if skewed_right.count() < 1000000:
    skewed_result = skewed_left.join(
        F.broadcast(skewed_right), "key"
    )
else:
    # Option 2: salt the hot keys. Repartitioning by "key" alone would not help,
    # because all rows of one key hash to the same partition regardless of count
    skewed_result = salted_join(skewed_left, skewed_right, "key")

# Union results (both joins produce the same column order)
final_result = normal_result.union(skewed_result)
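Raising the partition count with `repartition(n, "key")` cannot split a single hot key: hash partitioning sends every row with the same key to the same partition, so more partitions shrink the others but never the hot one. The plain-Python sketch below models this with `md5` as a stand-in for Spark's Murmur3 hash (an assumption; only the same-key, same-partition property matters). `fullest_partition` is a hypothetical helper.

```python
import hashlib
from collections import Counter


def partition_of(key, num_partitions):
    # md5 is a stand-in for Spark's Murmur3 hash partitioner (assumption);
    # only the "same key -> same partition" property matters here
    return int(hashlib.md5(key.encode()).hexdigest(), 16) % num_partitions


def fullest_partition(keys, num_partitions):
    """Row count of the largest partition under hash partitioning by key."""
    sizes = Counter(partition_of(k, num_partitions) for k in keys)
    return max(sizes.values())


keys = ["hot_key"] * 50_000 + [f"key_{i}" for i in range(50_000)]
for p in (200, 1000):
    # The hot key's 50,000 rows always share one partition, whatever p is
    print(p, fullest_partition(keys, p))
```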

6. Bucketing for Repeated Joins

# Bucketing: pre-partition data by key at write time
# Eliminates the shuffle for repeated joins on the same key
# Note: it does not split hot keys; all rows of a key still land in one bucket

# Write bucketed tables
left_df.write \
    .bucketBy(200, "key") \
    .sortBy("key") \
    .mode("overwrite") \
    .saveAsTable("left_bucketed")

right_df.write \
    .bucketBy(200, "key") \
    .sortBy("key") \
    .mode("overwrite") \
    .saveAsTable("right_bucketed")

# Join bucketed tables (no shuffle!)
left_bucketed = spark.table("left_bucketed")
right_bucketed = spark.table("right_bucketed")
result = left_bucketed.join(right_bucketed, "key")

# Bucketing mathematical analysis:
# Let P = bucket count, N = left rows, M = right rows
# Each bucket: N / P rows (if keys are uniformly distributed)
# Join: local join within each bucket pair (no shuffle)
#
# Bucketing vs shuffle join:
# Shuffle join: O(N + M) rows shuffled over the network on every join
# Bucketed join: no shuffle at join time (the cost is paid once at write time)
#
# Bucketing vs salting:
# Salting: runtime salt addition, temporary
# Bucketing: write-time distribution, persistent

# Bucketing best practices:
# 1. Use same bucket count for both tables
# 2. Use same bucket key for both tables
# 3. Bucket count should be ≥ total executor cores for parallelism
# 4. Target bucket size: 128-256 MB

# Bucket count calculation:
num_executors = 50
cores_per_executor = 4
total_cores = num_executors * cores_per_executor  # 200

# Use total_cores as bucket count for full parallelism
# Or use data_size / target_bucket_size
data_size_gb = 100
target_bucket_mb = 128
optimal_buckets = int(data_size_gb * 1024 / target_bucket_mb)  # 800

# Use max(total_cores, optimal_buckets) for both parallelism and size
bucket_count = max(total_cores, optimal_buckets)  # 800
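Bucketing removes the shuffle, not the skew: every row of a key still lands in one bucket, so the task that joins that bucket pair processes the whole hot key. The plain-Python arithmetic below uses a hypothetical helper, `hot_bucket_ratio`, and assumes the `left_df` distribution (1M `hot_key` rows out of 10M, remaining rows spread evenly). Raising the bucket count shrinks the average bucket but not the hot one, so the ratio gets worse.

```python
def hot_bucket_ratio(total_rows, hot_rows, buckets):
    """How many times larger the hot key's bucket is than the average bucket.

    Assumes the non-hot rows spread evenly across buckets (simplified model).
    """
    avg_bucket = total_rows / buckets
    hot_bucket = hot_rows + (total_rows - hot_rows) / buckets
    return hot_bucket / avg_bucket


for buckets in (200, 800):
    print(buckets, round(hot_bucket_ratio(10_000_000, 1_000_000, buckets), 1))
# 200 20.9
# 800 80.9
```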

7. Hybrid Approaches

# Combine multiple strategies for complex skew scenarios

def hybrid_skew_solution(left_df, right_df, join_key):
    """Hybrid approach: AQE + isolation + salting (inner join)."""

    # Step 1: Enable AQE for automatic handling
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

    # Step 2: Isolate known hot keys
    # NULL keys never match in an inner equi-join; isin() evaluates to NULL
    # for them, so both filters below drop them
    hot_keys = ["unknown"]
    is_hot = F.col(join_key).isin(hot_keys)
    normal_left = left_df.filter(~is_hot)
    hot_left = left_df.filter(is_hot)

    normal_right = right_df.filter(~is_hot)
    hot_right = right_df.filter(is_hot)

    # Step 3: Process normal data (AQE handles remaining skew)
    normal_result = normal_left.join(normal_right, join_key)

    # Step 4: Process hot keys with salting (limit(1) avoids a full count)
    if hot_left.limit(1).count() > 0:
        salt_range = 100
        # Random salt on the hot left rows
        hot_left_salted = hot_left.withColumn(
            "salt", (F.rand() * salt_range).cast("int")
        )

        # Replicate each hot right row once per salt value
        hot_right_salted = hot_right.withColumn(
            "salt", F.explode(F.sequence(F.lit(0), F.lit(salt_range - 1)))
        )

        # Salted join on (key, salt)
        hot_result = hot_left_salted.join(
            hot_right_salted, [join_key, "salt"]
        ).drop("salt")

        # Union all results (same column order on both sides)
        return normal_result.union(hot_result)

    return normal_result

# Apply hybrid solution
result = hybrid_skew_solution(left_df, right_df, "key")
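The key property of the hybrid, and of any salted join, is that it returns exactly the rows of the plain join. The plain-Python sketch below checks that on a toy dataset; `plain_join` and `hybrid_join` are hypothetical helpers that model the DataFrame logic with lists of `(key, value)` pairs, and round-robin salts replace `F.rand()` so the check is reproducible.

```python
from collections import Counter


def plain_join(left, right):
    """Inner equi-join of (key, value) pairs, returning (key, left_value, right_value)."""
    index = {}
    for k, v in right:
        index.setdefault(k, []).append(v)
    return [(k, lv, rv) for k, lv in left for rv in index.get(k, [])]


def hybrid_join(left, right, hot_keys, salt_range=4):
    """Normal join for ordinary keys plus a salted join for hot keys (toy model)."""
    normal = plain_join([r for r in left if r[0] not in hot_keys],
                        [r for r in right if r[0] not in hot_keys])
    # Left: tag each hot row with a salt (round-robin here instead of rand())
    left_salted = [((k, i % salt_range), v)
                   for i, (k, v) in enumerate(left) if k in hot_keys]
    # Right: one copy of each hot row per salt value (the explode(sequence(...)) step)
    right_salted = [((k, s), v)
                    for k, v in right if k in hot_keys for s in range(salt_range)]
    # Join on the composite (key, salt), then strip the salt
    salted = [(k, lv, rv) for (k, _), lv, rv in plain_join(left_salted, right_salted)]
    return normal + salted


left = [("hot", i) for i in range(10)] + [("a", 100), ("b", 200)]
right = [("hot", "x"), ("hot", "y"), ("a", "p"), ("c", "q")]
result = hybrid_join(left, right, hot_keys={"hot"})
print(len(result))  # 21
print(Counter(result) == Counter(plain_join(left, right)))  # True
```

Because every hot right row is copied to every salt value, each salted left row still meets all of its matches exactly once, which is why the row multisets agree.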

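Common Pitfall

Salting adds overhead when applied to non-skewed data: the replicated side grows by a factor of the salt range, so the join shuffles more rows for no benefit. Always check whether skew exists before applying salting, and use a skew detection function to decide whether intervention is needed.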

Interview Tip

When discussing skew solutions, always mention that AQE is the first line of defense in Spark 3.x (it is enabled by default since Spark 3.2). Manual salting is for cases where AQE doesn't handle skew well (e.g., extreme skew, complex join patterns).


Summary

| Strategy | When to Use | Implementation | Improvement |
|---|---|---|---|
| AQE Skew Join | Default for Spark 3.x | Zero config | 2-10x |
| Salting | Extreme skew, AQE insufficient | Manual code | 5-100x |
| Isolate Keys | Known hot keys | Split + special handling | 10-100x |
| Bucketing | Repeated joins on same key | Write-time distribution | Eliminates shuffle |
| Hybrid | Complex scenarios | Combine strategies | Maximum improvement |

The key to solving data skew is: detect first, then apply the simplest solution that works. Start with AQE, escalate to salting if needed, and use bucketing for repeated operations.
