PySpark Advanced Interview Series
Module 09: Catalyst & Tungsten - The Optimization Engine
Interview Questions
"At Google, we optimize query performance by understanding Catalyst's optimization phases. Walk us through the complete query optimization pipeline from logical plan to physical execution. How does predicate pushdown work, and what role does the cost-based optimizer play?" - Google Data Engineer Interview
"At Netflix, we tune Spark queries for petabyte-scale workloads. Explain how Tungsten's whole-stage code generation works, how you would read and interpret an EXPLAIN plan, and what techniques you use to identify optimization opportunities." - Netflix Senior Data Engineer Interview
Catalyst Optimizer Pipeline
Catalyst is Spark's extensible query optimization framework. It transforms queries through four phases:
SQL/DataFrame API
↓
Parsed Logical Plan
↓
Analyzed Logical Plan
↓
Optimized Logical Plan
↓
Physical Plans (multiple)
↓
Selected Physical Plan
↓
RDD Code (via Tungsten)
Phase 1: Analysis
Catalyst resolves references and validates the query:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as _sum
spark = SparkSession.builder.appName("CatalystInterview").getOrCreate()
# Read data
sales = spark.read.parquet("s3a://bucket/sales/")
products = spark.read.parquet("s3a://bucket/products/")
# This query will be analyzed
result = sales \
    .join(products, "product_id") \
    .filter(col("category") == "Electronics") \
    .groupBy("product_name") \
    .agg(_sum("revenue").alias("total_revenue"))
# View the analyzed logical plan
result.explain("extended")
The analyzer:
- Resolves column references (e.g., category → products.category)
- Validates data types
- Checks for ambiguous references
- Fills in implicit defaults (e.g., join type)
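To make the resolution step concrete, here is a toy sketch in plain Python. It is not Spark's analyzer; the `resolve_column` helper and the `schemas` catalog are hypothetical names introduced only for illustration, and the column lists are assumed from the queries above.

```python
# Toy illustration of reference resolution; not Spark's real analyzer.

def resolve_column(name, schemas):
    """Return 'table.column' for an unqualified column name.

    schemas maps a table name to the list of its column names.
    Raises ValueError for unknown or ambiguous references.
    """
    owners = [table for table, cols in schemas.items() if name in cols]
    if not owners:
        raise ValueError(f"cannot resolve column '{name}'")
    if len(owners) > 1:
        raise ValueError(f"column '{name}' is ambiguous: {sorted(owners)}")
    return f"{owners[0]}.{name}"


# Assumed schemas for the sales and products tables.
schemas = {
    "sales": ["product_id", "revenue", "date"],
    "products": ["product_id", "product_name", "category"],
}

print(resolve_column("category", schemas))  # products.category
```

In this sketch `product_id` would be reported as ambiguous because both tables carry it. Spark avoids that for the query above because joining on the column name `"product_id"` yields a single output column.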
Phase 2: Optimization (Logical)
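Catalyst always applies rule-based optimizations; cost-based optimizations additionally require spark.sql.cbo.enabled (off by default) and collected table statistics: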
Predicate Pushdown
# Catalyst pushes filters closer to data source
result = sales \
    .filter(col("date") == "2024-01-01") \
    .join(products, "product_id") \
    .filter(col("category") == "Electronics")
# Catalyst rewrites to:
# 1. Filter sales by date (pushed to scan)
# 2. Filter products by category (pushed to scan)
# 3. Join filtered results
# Verify with explain
result.explain(True)
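The rewrite described above can be sketched as a small rule over a plan tree. This is a toy model in plain Python, not Catalyst's implementation; the `Scan`, `Join`, and `Filter` classes and the `push_down` function are hypothetical names, and predicates are kept as opaque strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Scan:
    table: str
    columns: tuple       # columns this table provides
    filters: tuple = ()  # predicates evaluated at the scan


@dataclass(frozen=True)
class Join:
    left: object
    right: object
    key: str


@dataclass(frozen=True)
class Filter:
    column: str      # column the predicate reads
    predicate: str   # predicate text, opaque to this sketch
    child: object


def output_columns(plan):
    """Columns visible above a plan node."""
    if isinstance(plan, Scan):
        return set(plan.columns)
    if isinstance(plan, Join):
        return output_columns(plan.left) | output_columns(plan.right)
    return output_columns(plan.child)


def _push_filter(column, predicate, node):
    """Sink one predicate as far toward a scan as the tree allows."""
    if isinstance(node, Scan):
        return Scan(node.table, node.columns, node.filters + (predicate,))
    if isinstance(node, Join):
        in_left = column in output_columns(node.left)
        in_right = column in output_columns(node.right)
        if in_left and not in_right:
            return Join(_push_filter(column, predicate, node.left), node.right, node.key)
        if in_right and not in_left:
            return Join(node.left, _push_filter(column, predicate, node.right), node.key)
        # Column comes from both sides: this toy keeps the filter above the join.
        return Filter(column, predicate, node)
    # node is another Filter: slide underneath it.
    return Filter(node.column, node.predicate, _push_filter(column, predicate, node.child))


def push_down(plan):
    """Rewrite the plan so every Filter sits as close to its scan as possible."""
    if isinstance(plan, Scan):
        return plan
    if isinstance(plan, Join):
        return Join(push_down(plan.left), push_down(plan.right), plan.key)
    return _push_filter(plan.column, plan.predicate, push_down(plan.child))


sales = Scan("sales", ("product_id", "revenue", "date"))
products = Scan("products", ("product_id", "product_name", "category"))
query = Filter(
    "category", "category = 'Electronics'",
    Join(Filter("date", "date = '2024-01-01'", sales), products, "product_id"),
)
optimized = push_down(query)
print(optimized.left.filters)   # ("date = '2024-01-01'",)
print(optimized.right.filters)  # ("category = 'Electronics'",)
```

After the rewrite the join sits at the root and each predicate has become part of the scan that provides its column, which mirrors the three steps listed in the comments above.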
Column Pruning
# Catalyst removes unused columns from scans
result = sales \
    .select("product_id", "revenue") \
    .join(products, "product_id") \
    .select("product_name", "revenue")
# Catalyst prunes unused columns from both tables
# Only reads product_id, revenue from sales
# Only reads product_id, product_name from products
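As a rough illustration of how the pruned column sets fall out, the sketch below computes, for each table, the columns needed either by the final projection or by the join. It is a simplification in plain Python; `required_columns` and the extra columns in `table_schemas` (`order_id`, `store_id`, `supplier`) are hypothetical.

```python
def required_columns(table_schemas, output_cols, join_keys):
    """For each table, keep only columns needed for the output or the join."""
    needed = set(output_cols) | set(join_keys)
    return {
        table: [c for c in cols if c in needed]
        for table, cols in table_schemas.items()
    }


# Assumed full schemas; only some columns are used by the query.
table_schemas = {
    "sales": ["order_id", "product_id", "revenue", "date", "store_id"],
    "products": ["product_id", "product_name", "category", "supplier"],
}

pruned = required_columns(table_schemas, ["product_name", "revenue"], ["product_id"])
print(pruned)
# {'sales': ['product_id', 'revenue'], 'products': ['product_id', 'product_name']}
```

With a columnar format such as Parquet, the columns that are dropped here are never read from storage at all.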
Constant Folding
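# Catalyst evaluates constant expressions during logical optimization.
# Note: a plain Python (1 + 0.1) is computed by Python before Spark sees it,
# so build the constant with lit() to get an expression that Catalyst folds.
from pyspark.sql.functions import col, lit
result = sales.withColumn(
    "adjusted_revenue",
    col("revenue") * (lit(1) + lit(0.1))  # Constant subexpression
)
# Catalyst's ConstantFolding rule rewrites the expression to:
# col("revenue") * 1.1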
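The idea behind constant folding can be shown with a toy expression tree. This is a sketch in plain Python rather than Catalyst's rule; the tuple encoding (`"lit"`, `"col"`, and operator nodes) and the `fold` function are assumptions made for the example.

```python
import operator

# Supported binary operators in this toy expression language.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def fold(expr):
    """Replace any operator node whose operands are both literals with a literal."""
    kind = expr[0]
    if kind in ("lit", "col"):
        return expr
    left, right = fold(expr[1]), fold(expr[2])
    if left[0] == "lit" and right[0] == "lit":
        return ("lit", OPS[kind](left[1], right[1]))
    return (kind, left, right)


# revenue * (1 + 0.1)
expr = ("*", ("col", "revenue"), ("+", ("lit", 1), ("lit", 0.1)))
folded = fold(expr)
print(folded)  # ('*', ('col', 'revenue'), ('lit', 1.1))
```

The multiplication survives because one operand is a column, whose value is only known per row at execution time.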
Join Reordering
# With the cost-based optimizer enabled (spark.sql.cbo.enabled and
# spark.sql.cbo.joinReorder.enabled) and table statistics collected,
# Catalyst can reorder joins based on table sizes and selectivity
result = large_table \
    .join(medium_table, "id") \
    .join(small_table, "id")
# Catalyst may reorder to:
# small_table.join(medium_table, "id").join(large_table, "id")
# Joining the smaller tables first reduces intermediate result size
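Since the interview question asks about the cost-based optimizer, here is a minimal configuration sketch. It assumes `spark` is an active SparkSession and that the tables are registered in the catalog (statistics cannot be collected this way for a DataFrame read directly from a path); the helper name `enable_cost_based_optimizer` is hypothetical.

```python
def enable_cost_based_optimizer(spark, tables):
    """Turn on CBO and join reordering, then collect statistics.

    spark:  an active SparkSession (assumed).
    tables: names of catalog tables to analyze (hypothetical here).
    """
    spark.conf.set("spark.sql.cbo.enabled", "true")
    spark.conf.set("spark.sql.cbo.joinReorder.enabled", "true")
    for table in tables:
        # The cost model relies on these statistics; without them the
        # optimizer falls back to coarse size estimates.
        spark.sql(f"ANALYZE TABLE {table} COMPUTE STATISTICS FOR ALL COLUMNS")


# Usage, assuming the three tables are catalog tables:
# enable_cost_based_optimizer(spark, ["large_table", "medium_table", "small_table"])
```

Collecting column statistics scans the data, so it is usually done after loads rather than before every query.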
Phase 3: Physical Planning
Catalyst generates candidate physical plans and selects one to execute:
# View the selected physical plan
result.explain("formatted")
# Illustrative output (simplified; exact text varies by Spark version):
# == Physical Plan ==
# *(2) Project [product_name, revenue]
# +- *(2) BroadcastHashJoin [product_id], [product_id], BuildLeft
#    :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, int, false]))
#    :  +- *(1) Filter (category = Electronics)
#    :     +- *(1) ColumnarToRow
#    :        +- FileScan parquet [product_id, product_name, category]
#    +- *(2) Filter (date = 2024-01-01)
#       +- *(2) ColumnarToRow
#          +- FileScan parquet [product_id, revenue, date]
Physical Plan Types
| Operator | Description | When Used |
|---|---|---|
| FileScan | Read from storage | Base scan of file-based sources |
| Filter | Apply predicate | WHERE, filter() |
| Project | Select columns or compute expressions | SELECT, select(), withColumn() |
| BroadcastHashJoin | Broadcast small table | One side below spark.sql.autoBroadcastJoinThreshold, or a broadcast hint |
| SortMergeJoin | Sort and merge | Large-large equi-join |
| ShuffledHashJoin | Hash join after shuffle | One side small enough to hash per partition |
| Sort | Sort data | ORDER BY, inputs to SortMergeJoin |
| HashAggregate / SortAggregate | Compute aggregations | GROUP BY, aggregations |
Phase 4: Code Generation (Tungsten)
Tungsten generates Java source code at runtime and compiles it to JVM bytecode:
# Tungsten optimizations:
# 1. Whole-stage code generation (fuses multiple operators into one function)
# 2. Cache-aware computation (uses CPU cache efficiently)
# 3. Off-heap memory management (reduces GC pressure)
# Inspect the generated code
result.explain("codegen")
# Output shows the generated Java code (simplified, illustrative):
# == Generated Code ==
# class GeneratedClass {
#     public Object generate(Object[] references) {
#         return new CodeIterator(references);
#     }
# }
ℹ️ Google Interview Insight
At Google, understanding Tungsten's code generation is crucial. Whole-stage code generation fuses multiple operators (filter + project + aggregate) into a single generated function. This removes per-row virtual function calls between operators and keeps intermediate values in CPU registers, which can improve CPU efficiency by roughly 2-10x on CPU-bound queries.
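The contrast between operator-at-a-time execution and a fused, generated function can be sketched in plain Python. This is only an analogy: Spark generates Java source, not Python, and the function names and sample rows below are hypothetical.

```python
# Toy analogy for whole-stage code generation; Spark itself emits Java.

def volcano_pipeline(rows):
    """Operator-at-a-time: each operator is its own iterator."""
    filtered = (r for r in rows if r["category"] == "Electronics")  # Filter
    projected = (r["revenue"] for r in filtered)                    # Project
    return sum(projected)                                           # Aggregate


def generate_fused_source(filter_col, filter_value, agg_col):
    """Build the source of one function that fuses filter + project + sum."""
    return (
        "def fused(rows):\n"
        "    total = 0\n"
        "    for r in rows:\n"
        f"        if r[{filter_col!r}] == {filter_value!r}:\n"
        f"            total += r[{agg_col!r}]\n"
        "    return total\n"
    )


def compile_fused(source):
    """Compile the generated source and return the resulting function."""
    namespace = {}
    exec(compile(source, "<generated>", "exec"), namespace)
    return namespace["fused"]


rows = [
    {"category": "Electronics", "revenue": 100},
    {"category": "Books", "revenue": 40},
    {"category": "Electronics", "revenue": 60},
]

source = generate_fused_source("category", "Electronics", "revenue")
fused = compile_fused(source)
print(source)
print(fused(rows))             # 160
print(volcano_pipeline(rows))  # 160
```

Both versions return the same result. The fused one does the work in a single loop with no iterator hand-offs between operators, which is the effect whole-stage code generation aims for on the JVM.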
Reading EXPLAIN Plans
# Simple explain
df.explain()
# Extended explain (parsed, analyzed, and optimized logical plans plus the physical plan)
df.explain("extended")
# Formatted explain (most readable)
df.explain("formatted")
# Codegen explain (shows generated code)
df.explain("codegen")
# Cost explain (optimized logical plan with statistics, when available)
df.explain("cost")
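Because explain() prints its output rather than returning it, a small helper can be handy when you want to search or log a plan. The sketch below assumes that PySpark writes the plan through Python's standard output, so it can be captured with `redirect_stdout`; `capture_explain` is a hypothetical helper name.

```python
import contextlib
import io


def capture_explain(df, mode="formatted"):
    """Return the text that df.explain(mode) would print.

    df is expected to be a PySpark DataFrame; any object with a
    compatible explain(mode) method works for this sketch.
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        df.explain(mode)
    return buffer.getvalue()


# Usage, assuming df is a DataFrame:
# plan_text = capture_explain(df, "formatted")
# print("Exchange" in plan_text)
```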
Plan Reading Example
from pyspark.sql.functions import col, sum as _sum, broadcast
result = sales \
    .filter(col("date") >= "2024-01-01") \
    .select("product_id", "revenue") \
    .join(broadcast(products), "product_id") \
    .groupBy("category") \
    .agg(_sum("revenue").alias("total_revenue"))
result.explain("formatted")
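Output (illustrative and simplified; exact operators and formatting vary by Spark version and settings):
== Physical Plan ==
*(3) HashAggregate(keys=[category], functions=[sum(revenue)])
+- Exchange hashpartitioning(category, 200)
   +- *(2) HashAggregate(keys=[category], functions=[partial_sum(revenue)])
      +- *(2) Project [revenue, category]
         +- *(2) BroadcastHashJoin [product_id], [product_id], Inner, BuildRight
            :- *(2) Project [product_id, revenue]
            :  +- *(2) Filter (date >= 2024-01-01)
            :     +- *(2) ColumnarToRow
            :        +- FileScan parquet [product_id, revenue, date]
            +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, int, false]))
               +- *(1) Filter isnotnull(product_id)
                  +- *(1) ColumnarToRow
                     +- FileScan parquet [product_id, category]
Reading this plan from the bottom up:
- Innermost: Parquet scans read only the listed columns
- Filter: applies the date predicate to sales
- Project: keeps product_id and revenue
- Broadcast: sends products to all executors
- Join: joins filtered sales with products on product_id, without shuffling sales
- Partial aggregate: pre-aggregates revenue by category within each partition
- Exchange: shuffles the partial results by category
- Final aggregate: merges the partial sums per category
- The *(n) prefix marks operators fused into whole-stage codegen stage n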
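When plans get long, it can help to count the expensive operators instead of reading every line. The sketch below scans plan text for shuffles, broadcasts, and join strategies. It is a heuristic over the printed format, which differs between Spark versions; `summarize_plan` and the sample plan string are hypothetical.

```python
import re
from collections import Counter

# Join operators this sketch looks for; extend as needed.
JOIN_OPERATORS = (
    "BroadcastHashJoin",
    "SortMergeJoin",
    "ShuffledHashJoin",
    "BroadcastNestedLoopJoin",
    "CartesianProduct",
)


def summarize_plan(plan_text):
    """Count shuffles, broadcasts, and join strategies in a physical plan string."""
    counts = Counter()
    for line in plan_text.splitlines():
        # Strip tree-drawing prefixes such as '+- ', ':  ', '*(2) ' or '* '.
        operator = re.sub(r"^[\s:|+\-]*(\*(\(\d+\))?\s*)?", "", line)
        name = operator.split(" ", 1)[0].split("(", 1)[0]
        if name == "Exchange":
            counts["shuffles"] += 1
        elif name == "BroadcastExchange":
            counts["broadcasts"] += 1
        elif name in JOIN_OPERATORS:
            counts[name] += 1
    return dict(counts)


# A hypothetical plan string, shortened for the example.
sample_plan = """\
*(3) HashAggregate(keys=[category], functions=[sum(revenue)])
+- Exchange hashpartitioning(category, 200)
   +- *(2) HashAggregate(keys=[category], functions=[partial_sum(revenue)])
      +- *(2) BroadcastHashJoin [product_id], [product_id], Inner, BuildRight
         :- *(2) Filter (date >= 2024-01-01)
         +- BroadcastExchange HashedRelationBroadcastMode(...)
"""

print(summarize_plan(sample_plan))
# {'shuffles': 1, 'BroadcastHashJoin': 1, 'broadcasts': 1}
```

A single shuffle for a GROUP BY is expected. Several Exchange nodes feeding a SortMergeJoin with a small table on one side is the kind of pattern worth investigating.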
Real-World Scenario: Google Query Optimization
Problem Statement
Optimize a complex analytics query that took 45 minutes to run. The query joins three tables, applies multiple filters, and computes several aggregations.
from pyspark.sql import SparkSession
# Explicit imports instead of a wildcard; note that this sum shadows Python's built-in sum
from pyspark.sql.functions import avg, broadcast, col, count, sum
spark = SparkSession.builder \
    .appName("GoogleQueryOptimization") \
    .config("spark.sql.shuffle.partitions", "500") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024)) \
    .getOrCreate()
# === ORIGINAL SLOW QUERY (45 minutes) ===
def slow_query():
    # Read all data without pruning
    orders = spark.read.parquet("s3a://google-data/orders/")
    customers = spark.read.parquet("s3a://google-data/customers/")
    products = spark.read.parquet("s3a://google-data/products/")
    # Complex query without optimization hints
    result = orders \
        .join(customers, "customer_id") \
        .join(products, "product_id") \
        .filter(col("order_date") >= "2024-01-01") \
        .filter(col("status") == "completed") \
        .groupBy("category", "region") \
        .agg(
            sum("amount").alias("total_revenue"),
            count("*").alias("order_count"),
            avg("amount").alias("avg_order_value")
        )
    return result
# === OPTIMIZED QUERY (3 minutes) ===
def optimized_query():
    # 1. Read only needed columns (column pruning)
    orders = spark.read.parquet("s3a://google-data/orders/") \
        .select("order_id", "customer_id", "product_id", "amount", "status", "order_date")
    customers = spark.read.parquet("s3a://google-data/customers/") \
        .select("customer_id", "region")
    products = spark.read.parquet("s3a://google-data/products/") \
        .select("product_id", "category")
    # 2. Filter early (predicate pushdown)
    orders = orders \
        .filter(col("order_date") >= "2024-01-01") \
        .filter(col("status") == "completed")
    # 3. Broadcast the small dimension tables so the large orders table is
    #    joined without a shuffle. The filtered orders are read only once
    #    here, so caching or repartitioning them first would add work
    #    (an extra materialization and an extra shuffle) rather than save it.
    result = orders \
        .join(broadcast(products), "product_id") \
        .join(broadcast(customers), "customer_id") \
        .groupBy("category", "region") \
        .agg(
            sum("amount").alias("total_revenue"),
            count("*").alias("order_count"),
            avg("amount").alias("avg_order_value")
        )
    # 4. Cache the result for downstream use (materialized by the first action)
    result.cache()
    return result
# Compare performance
import time
start = time.time()
slow_result = slow_query()
slow_result.count()
slow_time = time.time() - start
print(f"Slow query: {slow_time:.0f}s")
start = time.time()
fast_result = optimized_query()
fast_result.count()
fast_time = time.time() - start
print(f"Optimized query: {fast_time:.0f}s")
# View optimized plan
fast_result.explain("formatted")
spark.stop()
AQE (Adaptive Query Execution)
Spark 3.0 introduced Adaptive Query Execution (AQE), which re-optimizes queries at runtime and is enabled by default since Spark 3.2:
# Enable AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")
# AQE features:
# 1. Auto coalesce shuffle partitions
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
# 2. Use a local shuffle reader after a sort-merge join is converted to a broadcast join
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
# 3. Handle skewed joins
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# 4. Tune skew detection: a partition is treated as skewed when it exceeds
#    both factor x median partition size and the byte threshold
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
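To see how the two skew settings interact, the sketch below applies the documented rule (larger than factor times the median partition size and larger than the byte threshold) to a list of partition sizes. It is a plain-Python illustration, not Spark code; `find_skewed_partitions` and the sample sizes are hypothetical.

```python
from statistics import median

MB = 1024 * 1024


def find_skewed_partitions(sizes_bytes, factor=5, threshold_bytes=256 * MB):
    """Return indexes of partitions that the skew rule would flag.

    A partition counts as skewed when it is larger than
    factor * median size AND larger than threshold_bytes.
    """
    med = median(sizes_bytes)
    return [
        index
        for index, size in enumerate(sizes_bytes)
        if size > factor * med and size > threshold_bytes
    ]


sizes = [64 * MB, 80 * MB, 72 * MB, 900 * MB, 70 * MB]
print(find_skewed_partitions(sizes))  # [3]
```

The second condition matters for small jobs: a 100 MB partition among 1 MB partitions is far above the median, yet it stays below the 256 MB threshold and is left alone.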
How AQE Works
# AQE re-optimizes at runtime based on actual data statistics:
# 1. After shuffle, it knows actual partition sizes
# 2. It can merge small partitions
# 3. It can split large partitions
# 4. It can change join strategies based on actual table sizes
# Example: Sort-merge join converted to broadcast join
# Before AQE: Both sides shuffled (expensive)
# After AQE: If one side is small, converted to broadcast join (fast)
# Verify AQE decisions
result.explain("formatted")
# Look for: "AdaptiveSparkPlan" in the plan
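The partition-merging step can be pictured as a greedy pass over the shuffle partition sizes. The sketch below is a simplification in plain Python: real AQE works from spark.sql.adaptive.advisoryPartitionSizeInBytes plus further constraints, and `coalesce_partitions` and the sample sizes are hypothetical.

```python
MB = 1024 * 1024


def coalesce_partitions(sizes_bytes, target_bytes):
    """Greedily merge adjacent partitions without exceeding target_bytes.

    Returns a list of groups, each a list of original partition indexes.
    A single partition larger than the target stays in its own group.
    """
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes_bytes):
        if current and current_size + size > target_bytes:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


sizes = [10 * MB, 20 * MB, 30 * MB, 5 * MB, 60 * MB, 4 * MB]
print(coalesce_partitions(sizes, 64 * MB))  # [[0, 1, 2], [3], [4, 5]]
```

Six shuffle partitions become three tasks, which is why a generous spark.sql.shuffle.partitions value is less risky when AQE is on.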
Performance Analysis Checklist
💡 Query Optimization Checklist
- Check EXPLAIN plan for unnecessary shuffles
- Verify predicate pushdown (filters near scan)
- Check column pruning (only needed columns read)
- Verify broadcast joins for small tables
- Check partition count (not too many/too few)
- Enable AQE for automatic optimization
- Cache DataFrames used multiple times
- Use bucketing for repeated joins
- Monitor Spark UI for stragglers
- Profile with Spark UI DAG visualization
Common Optimization Techniques
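| Technique | When to Use | Expected Improvement (rough, workload-dependent) |
|---|---|---|
| Predicate pushdown | Always (automatic; keep filters pushdown-friendly) | 2-10x |
| Column pruning | Always (automatic; select only needed columns) | 2-5x |
| Broadcast join | Small table < threshold | 2-100x |
| AQE | Spark 3.0+ (on by default since 3.2) | 1.5-3x |
| Caching | DataFrame used > once | 2-10x |
| Bucketing | Repeated joins on same key | 2-5x |
| Repartition | Data skew, pre-join | 2-10x |
| Kryo serialization | RDD code and serialized JVM objects (DataFrames already use Tungsten's binary format) | 2-10x |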
Edge Cases
1. UDFs Prevent Optimization
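# UDFs are black boxes: Catalyst can't optimize through them
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType
@udf(returnType=DoubleType())
def my_udf(x):
    # Return a float to match DoubleType, and pass nulls through
    return float(x) * 2 if x is not None else None
# Catalyst can't see inside the UDF, so it can't rewrite this filter in terms of "value"
df.withColumn("result", my_udf(col("value"))) \
    .filter(col("result") > 100)  # Filter applied AFTER UDF
# Better: use built-in functions
df.withColumn("result", col("value") * 2) \
    .filter(col("result") > 100)  # Catalyst can optimize this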
2. Skewed Data Prevents Optimization
# Catalyst assumes uniform distribution: skew breaks this
# Enable AQE skew handling
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# Or manually salt skewed keys
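Manual salting can be simulated in plain Python to show why it keeps the join result intact while spreading a hot key. In PySpark the same idea is usually expressed by adding a random salt column to the skewed side and replicating the other side once per salt value; the `salted_join` function and the sample rows below are hypothetical.

```python
import random
from collections import defaultdict


def salted_join(left_rows, right_rows, num_salts, seed=0):
    """Inner-join (key, value) rows through salted keys.

    left_rows:  large, possibly skewed side; each row gets one random salt.
    right_rows: small side; each row is replicated once per salt value.
    Returns the joined rows and the row count per salted key.
    """
    rng = random.Random(seed)
    # Replicate the small side across every salt value.
    right_index = defaultdict(list)
    for key, right_value in right_rows:
        for salt in range(num_salts):
            right_index[(key, salt)].append(right_value)
    # Salt the large side so a hot key spreads over num_salts buckets.
    joined, bucket_sizes = [], defaultdict(int)
    for key, left_value in left_rows:
        salted_key = (key, rng.randrange(num_salts))
        bucket_sizes[salted_key] += 1
        for right_value in right_index.get(salted_key, []):
            joined.append((key, left_value, right_value))
    return joined, dict(bucket_sizes)


# One hot key with 1000 rows and one small key with 10 rows.
left_rows = [("hot", i) for i in range(1000)] + [("cold", i) for i in range(10)]
right_rows = [("hot", "H"), ("cold", "C")]

joined, bucket_sizes = salted_join(left_rows, right_rows, num_salts=8)
largest_hot_bucket = max(n for (key, _), n in bucket_sizes.items() if key == "hot")
print(len(joined))           # 1010
print(largest_hot_bucket)    # well below 1000, since the hot key is spread over 8 salts
```

The trade-off is that the small side grows by a factor of `num_salts`, so salting is usually applied only to the keys known to be hot.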
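3. Partition Pruning
# Spark prunes partitions when filters reference partition columns directly.
# (Dynamic partition pruning, Spark 3.0+, extends this to joins, where the
# partition filter is derived at runtime from the other side of the join.)
# This example assumes logs is partitioned by date
spark.sql("""
    SELECT * FROM logs
    WHERE date = '2024-01-01' -- Prunes to single partition
""")
# Avoid filtering through a function: a predicate on a derived expression
# does not map onto the date partition column
spark.sql("""
    SELECT * FROM logs
    WHERE to_date(timestamp) = '2024-01-01' -- No pruning!
""")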
Summary
Understanding Catalyst and Tungsten is essential for optimizing Spark queries at Google and Netflix scale. The optimizer applies rule-based and cost-based transformations to convert your query into efficient physical plans. By writing query-friendly code (filtering early, using broadcast hints, enabling AQE), you give Catalyst more opportunities to optimize.