PySpark Advanced Interview Series
Module 07: Caching & Persistence - Trading Memory for Speed
Interview Question
"At Amazon, our ML pipelines reuse intermediate DataFrames across multiple model training runs. Walk us through all caching and persistence options in Spark, when to use each, and how you would debug cache eviction issues when working with datasets that exceed executor memory." - Amazon Senior Data Engineer Interview
"At Google, we optimize BigQuery-like workloads using Spark. Explain the difference between caching and persistence, how Spark manages memory for cached DataFrames, and what happens when a cached DataFrame is evicted due to memory pressure." - Google Data Engineer Interview
Caching vs Persistence
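In Spark, cache() and persist() store intermediate results so they are not recomputed by every action. cache() is shorthand for persist() with the default storage level; persist() lets you choose the level explicitly. Both are lazy: nothing is stored until an action computes the DataFrame.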
from pyspark.sql import SparkSession
from pyspark import StorageLevel
spark = SparkSession.builder.appName("CachingInterview").getOrCreate()
# Read a large dataset
df = spark.read.parquet("s3a://bucket/large-dataset/")
# cache() = persist() with the default level: MEMORY_AND_DISK for DataFrames
# (shown as MEMORY_AND_DISK_DESER in PySpark 3.x); RDD.cache() uses MEMORY_ONLY
df.cache()
df.count()  # cache() is lazy; this action materializes it
# Persist with a specific storage level. A DataFrame keeps the level it was
# first cached with, so unpersist before switching to another level
df.unpersist()
df.persist(StorageLevel.MEMORY_ONLY)  # or MEMORY_AND_DISK, DISK_ONLY, ...
# Unpersist (remove from cache)
df.unpersist()
Storage Levels
MEMORY_ONLY
For RDDs, stores partitions as deserialized Java objects in the JVM heap; for DataFrames, Spark keeps cached data in a compressed in-memory columnar format. Fastest access, but limited by memory.
df.persist(StorageLevel.MEMORY_ONLY)
# Pros: Fastest access, no disk I/O
# Cons: Limited by executor storage memory
# If a partition doesn't fit in memory, it won't be cached
# and will be recomputed from lineage each time it is needed
MEMORY_AND_DISK
Default caching level for DataFrames. Stores partitions in memory first and writes those that don't fit to local disk.
df.persist(StorageLevel.MEMORY_AND_DISK)
# Pros: Handles datasets larger than memory; spilled partitions are read back
# from disk instead of being recomputed
# Cons: Disk I/O for spilled partitions, slower than MEMORY_ONLY
# Spark stores as much as possible in memory, spills rest to local disk
DISK_ONLY
Stores everything on local disk. Useful when memory is very limited.
df.persist(StorageLevel.DISK_ONLY)
# Pros: No memory pressure, can cache very large datasets
# Cons: Disk I/O is slow; local disk is lost with the executor, so lost
# partitions are recomputed from lineage
Serialized Variants
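# Serialized variants (Scala/Java API) save memory but cost CPU for deserialization:
#   rdd.persist(StorageLevel.MEMORY_ONLY_SER)
#   rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
# PySpark deprecated the _SER constants in 2.0 because Python data is always
# serialized, and current releases no longer define them
# DataFrame caches use a compressed columnar format regardless of the level
# Serialized format is typically much more compact than deserialized objects,
# but requires deserialization on each access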
Replicated Variants
# Replicated variants store 2 copies for fault tolerance
df.persist(StorageLevel.MEMORY_ONLY_2)
df.persist(StorageLevel.MEMORY_AND_DISK_2)
df.persist(StorageLevel.DISK_ONLY_2)
# Useful when executors are unreliable
# If one executor fails, data can be recovered from another replica
OFF_HEAP
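# Store cached blocks outside the JVM heap (avoids GC overhead for that data)
# Off-heap memory must be enabled when the session is created:
#   .config("spark.memory.offHeap.enabled", "true")
#   .config("spark.memory.offHeap.size", "4g")
df.persist(StorageLevel.OFF_HEAP)
# Pros: Cached data adds no garbage collection pressure
# Cons: Must be sized explicitly; data is stored serialized, so each access
# pays a deserialization cost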
Storage Level Comparison
| Level | Memory | Disk | Serialized | Replicated | Use Case |
|---|---|---|---|---|---|
| MEMORY_ONLY | ✓ | ✗ | ✗ | ✗ | Fits in memory, fast access |
| MEMORY_ONLY_SER | ✓ | ✗ | ✓ | ✗ | Memory-constrained, batch processing |
| MEMORY_AND_DISK | ✓ | ✓ | ✗ | ✗ | Default, spills to disk |
| MEMORY_AND_DISK_SER | ✓ | ✓ | ✓ | ✗ | Large datasets, memory-limited |
| DISK_ONLY | ✗ | ✓ | ✓ | ✗ | Very large datasets |
| MEMORY_ONLY_2 | ✓ | ✗ | ✗ | ✓ | Fault-tolerant, fits in memory |
| MEMORY_AND_DISK_2 | ✓ | ✓ | ✗ | ✓ | Fault-tolerant, spills to disk |
| OFF_HEAP | ✓ (off-heap) | ✓ | ✓ | ✗ | Avoid GC overhead |
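Each named level is a combination of the five flags that StorageLevel's constructor takes (useDisk, useMemory, useOffHeap, deserialized, replication). The plain-Python sketch below mirrors the Scala definitions so the table can be regenerated or checked; `Level`, `LEVELS` and `describe()` are illustrative names, not PySpark APIs, and the flags follow the Scala/Java API rather than PySpark's own constants.

```python
from typing import NamedTuple


class Level(NamedTuple):
    """Plain-Python mirror of Spark's StorageLevel flags (not the real class)."""
    use_disk: bool
    use_memory: bool
    use_off_heap: bool
    deserialized: bool
    replication: int = 1


# Flag combinations as defined in Spark's Scala StorageLevel object
LEVELS = {
    "MEMORY_ONLY": Level(False, True, False, True),
    "MEMORY_ONLY_SER": Level(False, True, False, False),
    "MEMORY_AND_DISK": Level(True, True, False, True),
    "MEMORY_AND_DISK_SER": Level(True, True, False, False),
    "DISK_ONLY": Level(True, False, False, False),
    "MEMORY_ONLY_2": Level(False, True, False, True, 2),
    "MEMORY_AND_DISK_2": Level(True, True, False, True, 2),
    "OFF_HEAP": Level(True, True, True, False),
}


def describe(name):
    """Summarize a level in the same terms as the comparison table."""
    lvl = LEVELS[name]
    return {
        "memory": lvl.use_memory,
        "disk": lvl.use_disk,
        "off_heap": lvl.use_off_heap,
        "serialized": not lvl.deserialized,
        "replicas": lvl.replication,
    }


for name in LEVELS:
    print(name, describe(name))
```

Reading the flags this way explains why DISK_ONLY counts as serialized: blocks written to disk are never stored as live objects.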
Memory Management
How Spark Manages Cached Data
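# Spark's unified memory manager (Spark 1.6+) divides the executor heap into:
# 1. Reserved memory (300MB) - for Spark internals
# 2. User memory: (heap - 300MB) × (1 - spark.memory.fraction) - user data structures
# 3. Unified memory: (heap - 300MB) × spark.memory.fraction, shared by
#    execution (shuffles, joins, sorts, aggregations) and storage (cached data, broadcasts)
# With the default fraction of 0.6 this is roughly 40% user, 30% execution, 30% storage
# Configuration: these are static settings, so set them before the SparkContext
# starts (spark.conf.set() at runtime does not change them)
spark = SparkSession.builder \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.5") \
    .getOrCreate()
# Note: Execution and storage share unified memory dynamically
# Execution can evict cached blocks, but only until storage drops to
# spark.memory.storageFraction of unified memory
# Storage can use free execution memory, but cannot evict execution data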
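The region sizes are simple arithmetic. A plain-Python sketch, assuming the unified memory manager's formulas (300MB reserved, then spark.memory.fraction and spark.memory.storageFraction applied to the rest); `memory_regions()` is a hypothetical helper, and it ignores off-heap memory and JVM rounding:

```python
RESERVED_MB = 300  # fixed reserved memory in the unified memory manager


def memory_regions(heap_mb, memory_fraction=0.6, storage_fraction=0.5):
    """Split an executor heap into unified-memory regions (all values in MB)."""
    usable = heap_mb - RESERVED_MB
    unified = usable * memory_fraction  # shared by execution and storage
    return {
        "reserved": RESERVED_MB,
        "user": usable - unified,
        "unified": unified,
        # cached blocks inside this share cannot be evicted by execution
        "storage_protected": unified * storage_fraction,
    }


# 16 GB heap with default settings
print(memory_regions(16 * 1024))
# 16 GB heap with spark.memory.fraction raised to 0.8
print(memory_regions(16 * 1024, memory_fraction=0.8))
```

With a 16GB heap and defaults this gives about 9.4GB of unified memory, of which about 4.7GB is storage that execution cannot evict; cached data can still grow into idle execution memory beyond that.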
Memory Overhead
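# Total container memory = spark.executor.memory + spark.executor.memoryOverhead
# Default overhead = max(384MB, 0.10 × spark.executor.memory)
# For a 16GB executor:
# Overhead = max(384MB, 1.6GB) = 1.6GB
# Total container = 17.6GB
# Cached blocks live in the heap (or in off-heap memory if enabled), not in the
# overhead, which covers JVM internals, native buffers, and Python worker processes
# With default settings, a 16GB heap has about 4.7GB of eviction-protected storage
# ((16GB - 300MB) × 0.6 × 0.5); storage can grow toward ~9.4GB when execution is idle
# Actual cached data may be less due to fragmentation and overhead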
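The overhead rule as a plain-Python sketch (`container_memory_mb()` is a hypothetical helper; the 0.10 factor and 384MB floor are the defaults quoted above, and an explicit spark.executor.memoryOverhead setting replaces the formula):

```python
MIN_OVERHEAD_MB = 384  # floor applied to the default overhead


def container_memory_mb(executor_memory_mb, overhead_factor=0.10):
    """Return (default overhead, total container request) for one executor, in MB."""
    overhead = max(MIN_OVERHEAD_MB, overhead_factor * executor_memory_mb)
    return overhead, executor_memory_mb + overhead


print(container_memory_mb(16 * 1024))  # 16 GB heap: the 10% rule applies
print(container_memory_mb(2 * 1024))   # small heap: the 384 MB floor applies
```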
⚠️ Amazon Interview Warning
At Amazon, a common production issue is OOM errors when caching large DataFrames. The solution is often not to increase memory but to:
- Use serialized storage (MEMORY_ONLY_SER in Scala/Java; PySpark levels are already serialized)
- Repartition so each cached partition is small enough to fit in storage memory
- Cache only needed columns (select before cache)
- Use OFF_HEAP (with spark.memory.offHeap.enabled) to avoid GC pressure
When to Cache
Good Candidates for Caching
from pyspark.sql.functions import col
# 1. DataFrame used multiple times
df = spark.read.parquet("s3a://bucket/data/")
df.cache()  # Lazy: nothing is stored until the first action
# dim_df is assumed to be a small dimension DataFrame with a "category" column
# Use 1: Filter and aggregate
result1 = df.filter(col("status") == "active") \
    .groupBy("category").count()
# Use 2: Join with dimension
result2 = df.join(dim_df, "category") \
    .select("name", "value")
# Use 3: Another transformation
result3 = df.withColumn("score", col("value") * 2)
# The first action (result1.show()) scans df and fills the cache;
# the later actions read the cached partitions
result1.show()
result2.show()
result3.show()
# Don't forget to unpersist when done
df.unpersist()
Bad Candidates for Caching
# 1. DataFrame used only once
df = spark.read.parquet("s3a://bucket/data/")
result = df.filter(col("status") == "active").count()
# No need to cache: used only once
# 2. DataFrame that changes frequently
# Caching stale data can lead to incorrect results
# 3. DataFrame larger than available memory (without disk spill)
# Partitions that don't fit are recomputed on every use,
# and caching it can evict other useful cached data
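The "used more than once" rule can be checked mechanically: without caching, every action recomputes its whole lineage, so a DataFrame is computed once per path from an action down to it. A plain-Python sketch over a hand-written lineage graph (node names and both helpers are hypothetical; it ignores Spark's reuse of identical exchanges inside a single query):

```python
from collections import Counter


def recompute_counts(parents, actions):
    """Count how often each node is computed when nothing is cached.

    parents: maps a node name to the names it is derived from.
    actions: node names that an action (count, show, write, fit) runs on.
    Each action re-executes its full lineage, so every path from an action
    down to a node adds one computation of that node.
    """
    counts = Counter()

    def visit(node):
        counts[node] += 1
        for parent in parents.get(node, []):
            visit(parent)

    for action in actions:
        visit(action)
    return counts


def cache_candidates(parents, actions):
    """Derived nodes (not raw sources) computed more than once."""
    counts = recompute_counts(parents, actions)
    return sorted(n for n, c in counts.items() if c > 1 and n in parents)


# The "good candidate" example above: three results built from one df
lineage = {
    "df": ["parquet_scan"],
    "result1": ["df"],
    "result2": ["df", "dim_df"],
    "result3": ["df"],
}
actions = ["result1", "result2", "result3"]
print(recompute_counts(lineage, actions))
print(cache_candidates(lineage, actions))
```

Here df (and its Parquet scan) would be computed three times, so df is the one node worth caching; dim_df is used by a single action and gains nothing.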
Real-World Scenario: Amazon ML Pipeline Caching
Problem Statement
Build a machine learning feature engineering pipeline that computes 50+ features from raw data. The feature computation is expensive (joins, aggregations, window functions). Multiple models need the same features. Optimize with caching.
from pyspark.sql import SparkSession
# Explicit imports; Spark's sum and max are aliased so Python's built-ins stay usable
from pyspark.sql.functions import (avg, col, count, countDistinct, hour,
                                   max as spark_max, sum as spark_sum,
                                   to_date, when)
from pyspark import StorageLevel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
spark = SparkSession.builder \
    .appName("AmazonMLCaching") \
    .config("spark.executor.memory", "16g") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()
# Read raw data
raw_data = spark.read.parquet("s3a://amazon-ml/user-events/")
user_profiles = spark.read.parquet("s3a://amazon-ml/user-profiles/")
product_catalog = spark.read.parquet("s3a://amazon-ml/product-catalog/")
# Step 1: Clean and standardize (cache this: every later step reads it)
clean_data = raw_data \
    .filter(col("event_type").isin(["view", "click", "purchase", "cart"])) \
    .filter(col("timestamp").isNotNull()) \
    .withColumn("event_date", to_date(col("timestamp"))) \
    .withColumn("hour_of_day", hour(col("timestamp"))) \
    .select("user_id", "product_id", "event_type", "event_date", "hour_of_day", "timestamp")
clean_data.persist(StorageLevel.MEMORY_AND_DISK)
clean_data.count()  # Force materialization (persist() is lazy)
# Step 2: Compute user-level features (cache for multiple models)
user_features = clean_data \
    .groupBy("user_id") \
    .agg(
        count("*").alias("total_events"),
        countDistinct("product_id").alias("unique_products"),
        countDistinct("event_date").alias("active_days"),
        spark_sum(when(col("event_type") == "purchase", 1).otherwise(0)).alias("purchases"),
        spark_sum(when(col("event_type") == "view", 1).otherwise(0)).alias("views"),
        spark_sum(when(col("event_type") == "click", 1).otherwise(0)).alias("clicks"),
        spark_sum(when(col("event_type") == "cart", 1).otherwise(0)).alias("cart_adds"),
        avg("hour_of_day").alias("avg_active_hour")
    ) \
    .withColumn("conversion_rate",
                when(col("views") > 0, col("purchases") / col("views")).otherwise(0.0)) \
    .withColumn("click_through_rate",
                when(col("views") > 0, col("clicks") / col("views")).otherwise(0.0)) \
    .withColumn("cart_abandonment_rate",
                when(col("cart_adds") > 0,
                     1 - (col("purchases") / col("cart_adds")))
                .otherwise(0.0))
user_features.persist(StorageLevel.MEMORY_AND_DISK)
user_features.count()  # Force materialization
# Step 3: Compute product-level features
product_features = clean_data \
    .groupBy("product_id") \
    .agg(
        spark_sum(when(col("event_type") == "view", 1).otherwise(0)).alias("total_views"),
        countDistinct("user_id").alias("unique_viewers"),
        spark_sum(when(col("event_type") == "purchase", 1).otherwise(0)).alias("total_purchases"),
        spark_sum(when(col("event_type") == "cart", 1).otherwise(0)).alias("total_cart_adds")
    ) \
    .withColumn("popularity_score", col("unique_viewers") / 1000) \
    .withColumn("purchase_rate",
                when(col("total_views") > 0,
                     col("total_purchases") / col("total_views")).otherwise(0.0))
product_features.persist(StorageLevel.MEMORY_AND_DISK)
product_features.count()
# Step 4: Build one row per (user, product) pair, then attach features.
# user_features has no product_id column, so it cannot be joined to
# product_features directly. The label (1 if the user bought the product)
# is an illustrative assumption.
interactions = clean_data \
    .groupBy("user_id", "product_id") \
    .agg(spark_max(when(col("event_type") == "purchase", 1).otherwise(0)).alias("label"))
feature_matrix = interactions \
    .join(user_features, "user_id", "left") \
    .join(product_features, "product_id", "left") \
    .join(user_profiles, "user_id", "left")
# Step 5: Model training (multiple models use same features)
# Model 1: Recommendation. Spark ML expects one vector column named "features"
rec_assembler = VectorAssembler(
    inputCols=["total_events", "conversion_rate", "popularity_score"],
    outputCol="features")
rec_train = rec_assembler.transform(feature_matrix).select("features", "label")
rec_train.persist(StorageLevel.MEMORY_AND_DISK)  # LR makes many passes over its input
lr_model = LogisticRegression().fit(rec_train)
# Model 2: Churn prediction (reuses cached user_features)
# The "churned" column in user_profiles is hypothetical
churn_assembler = VectorAssembler(
    inputCols=["active_days", "total_events", "conversion_rate"],
    outputCol="features")
churn_labels = user_profiles.select("user_id", col("churned").cast("double").alias("label"))
churn_train = churn_assembler.transform(user_features.join(churn_labels, "user_id")) \
    .select("features", "label")
churn_model = LogisticRegression().fit(churn_train)
# Clean up
rec_train.unpersist()
clean_data.unpersist()
user_features.unpersist()
product_features.unpersist()
Cache Monitoring and Debugging
Check What's Cached
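# PySpark has no public API that lists every cached DataFrame.
# Tables and temp views in the current database, with their cache status:
for table in spark.catalog.listTables():
    print(f"Table: {table.name}, cached: {spark.catalog.isCached(table.name)}")
# Check if a specific DataFrame is cached, and at which level
print(f"Is cached: {clean_data.is_cached}")
print(f"Storage level: {clean_data.storageLevel}")
# Per-executor (max storage memory, remaining storage memory) in bytes,
# via the internal Py4J handle (not a stable public API)
storage_info = spark.sparkContext._jsc.sc().getExecutorMemoryStatus()
print(storage_info)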
Spark UI Storage Tab
# Access cache information through the Storage tab of the Spark UI
# Navigate to: http://driver-node:4040/storage/
# Key metrics:
# - Storage Level: What type of caching
# - Size in Memory: How much memory used
# - Size on Disk: How much disk used for spilling
# - Cached Partitions / Fraction Cached: below 100% means some partitions
#   were never materialized or were dropped
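The Storage page is backed by Spark's monitoring REST API, which lists stored RDDs at `/api/v1/applications/<app-id>/storage/rdd` (the application id is available as `spark.sparkContext.applicationId`). A sketch that flags partially cached datasets, assuming the `name`, `numPartitions` and `numCachedPartitions` fields documented for that endpoint; the helper names are hypothetical. A fraction below 1.0 means some partitions were never materialized (for example, only `show()` ran) or were dropped:

```python
import json
from urllib.request import urlopen


def fetch_cached_rdds(ui_url, app_id):
    """GET the list of stored RDDs from Spark's monitoring REST API."""
    url = f"{ui_url}/api/v1/applications/{app_id}/storage/rdd"
    with urlopen(url) as resp:
        return json.load(resp)


def partially_cached(rdds):
    """Return (name, fraction cached) for entries missing some partitions."""
    report = []
    for rdd in rdds:
        total = rdd["numPartitions"]
        cached = rdd["numCachedPartitions"]
        if total and cached < total:
            report.append((rdd["name"], cached / total))
    return report


# Example (host and application id are placeholders):
# for name, frac in partially_cached(fetch_cached_rdds("http://driver-node:4040", "app-id")):
#     print(f"{name}: only {frac:.0%} of partitions cached")
```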
Cache Eviction
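# When storage memory is full, Spark evicts cached blocks in LRU order
# (least recently used first). It does not evict blocks of the same RDD it is
# currently caching: under MEMORY_ONLY, partitions that don't fit are simply not cached
# The Spark UI does not report cache hit rates directly. Signs of eviction:
# - Fraction Cached below 100% in the Storage tab
# - Executor log warnings such as "Not enough space to cache ... in memory!"
# If you see frequent eviction:
# 1. Increase executor memory
# 2. Use serialized storage (MEMORY_ONLY_SER in Scala/Java)
# 3. Repartition so each partition fits in storage memory
# 4. Cache only needed columns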
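Why repeated scans of an oversized dataset behave as they do can be seen in a toy model. Spark's MemoryStore evicts in LRU order but will not evict blocks of the RDD it is currently caching, so a dataset larger than storage memory ends up partially cached rather than thrashing. The plain-Python simulation below (equal-sized partitions, capacity counted in partitions; `simulate_scans()` is a hypothetical helper, not Spark code) compares that rule with naive LRU:

```python
from collections import OrderedDict


def simulate_scans(num_partitions, capacity, num_scans, protect_same_rdd=True):
    """Count cache hits when one cached dataset is scanned repeatedly.

    Simplified model: every partition has the same size and eviction is LRU.
    With protect_same_rdd=True, a partition that does not fit is left
    uncached instead of evicting another partition of the same dataset.
    """
    cache = OrderedDict()  # partition id -> True, least recently used first
    hits = 0
    for _ in range(num_scans):
        for p in range(num_partitions):
            if p in cache:
                hits += 1
                cache.move_to_end(p)  # mark as most recently used
                continue
            # Miss: the partition is recomputed from lineage
            if len(cache) < capacity:
                cache[p] = True
            elif not protect_same_rdd:
                cache.popitem(last=False)  # evict the least recently used
                cache[p] = True
            # else: leave it uncached; it is recomputed on every scan
    return hits


print(simulate_scans(10, 6, 5, protect_same_rdd=True))
print(simulate_scans(10, 6, 5, protect_same_rdd=False))
```

With 10 partitions and room for 6, Spark-style protection keeps 6 partitions cached across scans (24 hits over 5 scans), while naive LRU on a cyclic scan evicts every partition just before it is needed again (0 hits). Real thrashing in Spark usually comes from several cached datasets competing for the same storage memory.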
ℹ️ Google Interview Insight
At Google, cache eviction is a common production issue. When the data a job caches exceeds cluster storage memory, Spark keeps dropping partitions and recomputing them from lineage. The solution is to:
- Cache at the right granularity (after filtering and projection, at the point where the lineage branches into multiple uses)
- Use checkpointing instead of caching for very large datasets
- Monitor how much of each dataset stays cached (Storage tab, executor logs)
Cache vs Checkpoint vs Repartition
| Feature | Cache | Checkpoint | Repartition |
|---|---|---|---|
| Purpose | Avoid recomputation | Break lineage, save to storage | Change partition count |
| Storage | Memory/Disk (local) | Reliable storage (HDFS/S3) | None (shuffle files on local disk) |
| Lineage | Preserved | Broken | Preserved |
| Fault Tolerance | Recompute | Read from storage | Recompute |
| Performance | Fast access | Slower (remote I/O) | One-time shuffle cost |
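# Cache: fast access, lineage preserved
df.cache()
# Checkpoint: break long lineage chains (directory should be on reliable storage)
spark.sparkContext.setCheckpointDir("s3a://bucket/checkpoints/")
df = df.checkpoint()  # Returns a new DataFrame; eager by default
# Repartition: change partition count for next operation (returns a new DataFrame)
df = df.repartition(200)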
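Checkpointing matters most in iterative jobs, where each iteration stacks another transformation on the previous plan. A plain-Python sketch of lineage length with and without a periodic checkpoint (`lineage_depths()` is a hypothetical helper; it models plan length only, not cost), with the corresponding PySpark pattern in comments:

```python
def lineage_depths(iterations, checkpoint_every=None):
    """Lineage length after each iteration (1 = a materialized root)."""
    depth = 1  # the initial read
    depths = []
    for i in range(1, iterations + 1):
        depth += 1  # each iteration adds a transformation on top of the plan
        if checkpoint_every and i % checkpoint_every == 0:
            depth = 1  # checkpointed data becomes a new root
        depths.append(depth)
    return depths


print(lineage_depths(5))
print(lineage_depths(5, checkpoint_every=2))

# In PySpark the same idea looks like (checkpoint directory set first):
#     for i in range(1, iterations + 1):
#         df = step(df)              # step() is a placeholder transformation
#         if i % 10 == 0:
#             df = df.checkpoint()   # truncates the lineage
```

Without checkpoints the plan grows linearly with the iteration count; with one every k iterations it never exceeds k steps past the last checkpoint.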
Advanced Caching Patterns
Selective Column Caching
# Cache only needed columns to save memory
feature_df = raw_data.select(
"user_id", "product_id", "event_type", "timestamp"
)
feature_df.persist(StorageLevel.MEMORY_AND_DISK)
feature_df.count()
# Usually much smaller than caching all columns: only these four columns are stored
Tiered Caching
# Cache hot data in memory, warm data on disk
hot_data = df.filter(col("date") >= "2024-01-01")
hot_data.persist(StorageLevel.MEMORY_ONLY)
warm_data = df.filter((col("date") >= "2023-01-01") & (col("date") < "2024-01-01"))
warm_data.persist(StorageLevel.DISK_ONLY)
cold_data = df.filter(col("date") < "2023-01-01")
# Don't cache cold data: read it from storage on demand
Cache Invalidation
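# Spark does not detect changes made to source files outside Spark
# You must unpersist (or refresh) when the underlying data changes
# Pattern: Cache, use, unpersist
df.cache()
df.count()  # Force materialization
# ... use df multiple times ...
df.unpersist()
# Pattern: Check and recache (recomputes from the same lineage)
if df.is_cached:
    df.unpersist()
df.cache()
# If files under a cached path were rewritten, invalidate and refresh
# every cached DataFrame that reads that path
spark.catalog.refreshByPath("s3a://bucket/data/")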
Performance Comparison
import time
from pyspark.sql.functions import col
# dim_df is assumed to be a dimension DataFrame with an "id" column
# Without caching
start = time.time()
df = spark.read.parquet("s3a://bucket/large-data/")
result1 = df.filter(col("status") == "active").count()
result2 = df.groupBy("category").count().collect()  # show() would return None
result3 = df.join(dim_df, "id").count()
print(f"Without caching: {time.time() - start:.2f}s")
# With caching
start = time.time()
df = spark.read.parquet("s3a://bucket/large-data/")
df.cache()
df.count()  # Force materialization (included in the timing)
result1 = df.filter(col("status") == "active").count()
result2 = df.groupBy("category").count().collect()
result3 = df.join(dim_df, "id").count()
df.unpersist()
print(f"With caching: {time.time() - start:.2f}s")
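| Scenario | Without Cache | With Cache | Expected Effect |
|---|---|---|---|
| Single use | 1 source read | 1 source read + cache write | No benefit (slight overhead) |
| 3 uses | 3 source reads | 1 source read + cache write + 3 cache reads | Up to ~3x faster when cache reads are cheap |
| 10 uses | 10 source reads | 1 source read + cache write + 10 cache reads | Up to ~10x faster |
| Larger than memory (MEMORY_AND_DISK) | N source reads | Spilled partitions read from local disk | Gains shrink; can be slower than no cache |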
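A simple linear cost model makes the trade-off concrete: without caching each use pays a full source read R; with caching you pay one read plus a cache write W, then a cache read C per use. The sketch below is plain Python with hypothetical timings (`caching_speedup()` and `break_even_uses()` are illustrative helpers, and real jobs also differ in column pruning and shuffle work):

```python
import math


def caching_speedup(uses, read_s, write_s, cache_read_s):
    """Ratio of uncached to cached runtime under a linear cost model."""
    without_cache = uses * read_s
    with_cache = read_s + write_s + uses * cache_read_s
    return without_cache / with_cache


def break_even_uses(read_s, write_s, cache_read_s):
    """Smallest number of uses for which caching is not slower (None if never)."""
    if cache_read_s >= read_s:
        return None
    return math.ceil((read_s + write_s) / (read_s - cache_read_s))


# Hypothetical timings: 60s source read, 10s cache write, 5s cache read
for uses in (1, 3, 10):
    print(uses, round(caching_speedup(uses, 60, 10, 5), 2))
print(break_even_uses(60, 10, 5))
```

With those assumed timings, caching is slower for a single use (0.8x), about 2.1x faster for three uses and 5x faster for ten, and it pays off from the second use on. The speedup approaches uses × R / (R + W) only as C shrinks toward zero.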
Best Practices
💡 Production Caching Checklist
- Cache DataFrames used more than once
- Use MEMORY_AND_DISK as default (handles overflow)
- Use serialized storage (MEMORY_ONLY_SER in Scala/Java) for memory-constrained environments
- Always unpersist when done (don't rely on GC)
- Cache after filtering and projection, at the point where the lineage branches into multiple uses
- Monitor Fraction Cached and memory usage in the Spark UI Storage tab
- Don't cache datasets larger than cluster memory without disk spill
- Consider checkpointing for very long lineage chains
- Cache only needed columns (select before cache)
Summary
Caching is a powerful optimization tool in Spark, but it must be used judiciously. Caching too much wastes memory; caching too little wastes CPU on recomputation. Understanding storage levels, memory management, and cache monitoring helps you make informed decisions at Amazon and Google scale.