PySpark Advanced Interview Series
Module 02: RDDs - The Foundation of Distributed Computing
Interview Questions
"At Meta, we use RDDs for low-level distributed computing tasks where DataFrame abstractions don't suffice. Explain the difference between narrow and wide transformations, how RDD lineage enables fault tolerance, and demonstrate a custom partitioner implementation for a graph processing use case." - Meta Data Engineering Interview
"At Netflix, we process viewing activity streams using RDDs for complex window operations. Walk us through RDD persistence levels, how checkpointing breaks lineage, and the performance implications of using mapPartitions vs map." - Netflix Senior Data Engineer Interview
What is an RDD?
An RDD (Resilient Distributed Dataset) is Spark's fundamental data structure: an immutable, partitioned collection of elements that can be operated on in parallel. RDDs are the low-level building blocks upon which DataFrames and Datasets are built.
Key Properties
- Resilient: Fault-tolerant through lineage
- Distributed: Data is split across cluster nodes
- Dataset: A collection of records, split into partitions, that can hold primitive values or arbitrary objects
RDD vs DataFrame vs Dataset
| Feature | RDD | DataFrame | Dataset |
|---|---|---|---|
| Type Safety | Compile-time in Scala/Java; none in Python | No (schema checked at runtime) | Compile-time (Scala/Java only; no Dataset API in PySpark) |
| Optimization | None | Catalyst + Tungsten | Catalyst + Tungsten |
| API | Functional | Declarative | Declarative + Functional |
| Serialization | Java/Kryo (pickle in PySpark) | Tungsten binary | Tungsten binary |
| Use Case | Complex custom logic | Structured data processing | Typed structured data |
Creating RDDs
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("RDDInterview").getOrCreate()
sc = spark.sparkContext
# From a Python collection
rdd = sc.parallelize([1, 2, 3, 4, 5], numSlices=3)
# From an external file
rdd = sc.textFile("s3a://bucket/logs/*.txt")
# From a sequence file
rdd = sc.sequenceFile("hdfs://path/to/sequence-file")
# From another RDD
rdd2 = rdd.map(lambda x: x * 2)
# Empty RDD
empty_rdd = sc.emptyRDD()
# Union of RDDs
rdd_a = sc.parallelize([1, 2, 3])
rdd_b = sc.parallelize([4, 5, 6])
combined = rdd_a.union(rdd_b)
Narrow vs Wide Transformations
Narrow Transformations
Each input partition contributes to at most one output partition. No data movement across partitions (no shuffle).
# Narrow transformations: no shuffle
rdd = sc.parallelize(range(1, 1001), 10)
# map: 1-to-1 mapping
squared = rdd.map(lambda x: x ** 2)
# filter: selects subset
evens = rdd.filter(lambda x: x % 2 == 0)
# flatMap: 1-to-many mapping
words = rdd.flatMap(lambda x: [(x, 1), (x, 0)])
# mapPartitions: applies function to each partition
def process_partition(iterator):
    total = sum(iterator)
    yield total
partition_sums = rdd.mapPartitions(process_partition)
# mapPartitionsWithIndex: includes partition index
def process_with_index(idx, iterator):
    for val in iterator:
        yield f"partition_{idx}_{val}"
indexed = rdd.mapPartitionsWithIndex(process_with_index)
Wide Transformations
Each input partition can contribute to multiple output partitions. Requires data movement across the network (shuffle).
# Wide transformations: require a shuffle
rdd = sc.parallelize(range(1, 1001), 10)
# reduceByKey: aggregate by key (shuffle)
keyed_rdd = rdd.map(lambda x: (x % 10, x))
reduced = keyed_rdd.reduceByKey(lambda a, b: a + b)
# groupByKey: group values by key (expensive!)
grouped = keyed_rdd.groupByKey()
# sortByKey: sort by key
sorted_rdd = keyed_rdd.sortByKey()
# join: inner join two RDDs
rdd_a = sc.parallelize([(1, "a"), (2, "b"), (3, "c")])
rdd_b = sc.parallelize([(1, 10), (2, 20), (3, 30)])
joined = rdd_a.join(rdd_b)
# distinct: remove duplicates
duplicated = sc.parallelize([1, 1, 2, 2, 3, 3])
unique = duplicated.distinct()
# repartition: changes number of partitions (full shuffle)
repartitioned = rdd.repartition(20)
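# coalesce: reduces partitions by merging them locally. With the default shuffle=False
# this is a narrow dependency; it is listed here only for contrast with repartition
coalesced = rdd.coalesce(5)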
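The dependency difference between the two families can be modeled without a cluster. The sketch below is plain Python, not Spark: partitions are ordinary lists, `narrow_map` and `wide_repartition_by_key` are hypothetical helpers, and the modulo rule is only a stand-in for a real partitioner. It records which output partitions each input partition feeds.

```python
def narrow_map(partitions, fn):
    """Apply fn element-wise; output partition i depends only on input partition i."""
    outputs = [[fn(x) for x in part] for part in partitions]
    deps = {i: {i} for i in range(len(partitions))}
    return outputs, deps

def wide_repartition_by_key(partitions, num_out):
    """Route each (key, value) record by key; one input partition may feed many outputs."""
    outputs = [[] for _ in range(num_out)]
    deps = {}
    for i, part in enumerate(partitions):
        for key, value in part:
            target = key % num_out  # stand-in for a real partitioner
            outputs[target].append((key, value))
            deps.setdefault(i, set()).add(target)
    return outputs, deps

parts = [[1, 2, 3], [4, 5, 6]]
keyed, narrow_deps = narrow_map(parts, lambda x: (x % 3, x))
shuffled, wide_deps = wide_repartition_by_key(keyed, num_out=3)
print(narrow_deps)  # each input feeds exactly one output
print(wide_deps)    # each input feeds several outputs
```

In this model every input partition of the wide step contributes to all three output partitions, which is the data movement a shuffle has to perform.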
⚠️ Meta Interview Warning
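At Meta, interviewers penalize candidates who use groupByKey() without justification. Prefer reduceByKey() or aggregateByKey() when the goal is an aggregate, because they combine values within each partition before shuffling. groupByKey() sends every value across the network to the executors, and all values for a given key must fit in the memory of a single executor, which can cause OOM on skewed keys.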
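To make the shuffle-volume argument concrete, here is a minimal plain-Python model (not Spark code) that counts how many records would cross the shuffle boundary under each strategy. The helper names are hypothetical, and the model ignores serialization and spill costs.

```python
def shuffled_records_group(partitions):
    """groupByKey-style: every record crosses the shuffle boundary."""
    return sum(len(part) for part in partitions)

def shuffled_records_reduce(partitions, combine):
    """reduceByKey-style: combine inside each partition, then shuffle one record per key."""
    total = 0
    for part in partitions:
        local = {}
        for key, value in part:
            local[key] = combine(local[key], value) if key in local else value
        total += len(local)
    return total

# 4 partitions of 250 records each, with only 10 distinct keys
partitions = [[(x % 10, x) for x in range(p * 250, (p + 1) * 250)] for p in range(4)]
group_cost = shuffled_records_group(partitions)
reduce_cost = shuffled_records_reduce(partitions, lambda a, b: a + b)
print(group_cost, reduce_cost)
```

With few distinct keys per partition the combined version shuffles a small fraction of the records; when nearly every key is unique the two approaches converge.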
Actions: Triggering Execution
Actions trigger computation and return results to the driver or write to storage.
rdd = sc.parallelize(range(1, 1001), 10)
# Return all elements as a list to the driver
all_data = rdd.collect() # WARNING: can OOM driver with large data
# Return first n elements
first_10 = rdd.take(10)
# Return a single element
first = rdd.first()
# Count elements
count = rdd.count()
# Aggregate
total = rdd.reduce(lambda a, b: a + b)
# Save to external storage
rdd.saveAsTextFile("s3a://bucket/output/")
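# foreach: apply a function to each element for its side effects.
# The function runs on the executors, so print output lands in executor logs
# rather than the driver console (except in local mode)
rdd.foreach(lambda x: print(x))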
# Count by key
keyed_rdd = rdd.map(lambda x: (x % 10, x))
counts = keyed_rdd.countByKey()
# Collect as a dictionary
as_dict = keyed_rdd.collectAsMap()
# Take ordered
top_5 = rdd.takeOrdered(5, key=lambda x: -x)
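The split between lazy transformations and eager actions has a loose analogy in Python generators. The sketch below is only that analogy, in plain Python: Spark builds a DAG rather than a generator, and `traced_double` is a hypothetical function used to observe when work actually happens.

```python
calls = []

def traced_double(x):
    calls.append(x)  # record the moment the function actually runs
    return x * 2

# "Transformation": building the generator runs none of the per-element work
pipeline = (traced_double(x) for x in range(5))
calls_before_action = len(calls)

# "Action": consuming the generator triggers the work
result = list(pipeline)
calls_after_action = len(calls)

print(calls_before_action, calls_after_action, result)
```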
RDD Lineage and Fault Tolerance
What is Lineage?
Every RDD maintains a lineage: a record of the transformations used to build it. If a partition is lost, Spark recomputes only that partition (plus any parent partitions it depends on) by replaying the lineage graph.
# Build a pipeline
rdd = sc.parallelize(range(1, 1001), 10)
mapped = rdd.map(lambda x: x * 2)
filtered = mapped.filter(lambda x: x > 100)
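total = filtered.reduce(lambda a, b: a + b)  # reduce is an action: it returns a number, not an RDD
# View the lineage of the last RDD in the chain
print(filtered.toDebugString().decode())
Illustrative output (RDD IDs and call sites vary by Spark version; PySpark pipelines map and filter into a single PythonRDD):
(10) PythonRDD[2] at RDD at PythonRDD.scala:53 []
 |  ParallelCollectionRDD[0] at readRDDFromFile at PythonRDD.scala:289 []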
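Lineage-based recovery can be sketched with a toy class. This is a plain-Python model under simplifying assumptions, not Spark's implementation: `MiniRDD` is hypothetical, it supports only narrow dependencies, and "losing" a partition just means recomputing it on demand.

```python
class MiniRDD:
    """Toy stand-in for an RDD: stores its parent and the function that derives it."""

    def __init__(self, num_partitions, parent=None, fn=None, source=None):
        self.num_partitions = num_partitions
        self.parent = parent  # lineage pointer
        self.fn = fn          # per-partition function applied to the parent's data
        self.source = source  # root only: partition index -> iterable of elements

    def map(self, f):
        return MiniRDD(self.num_partitions, parent=self,
                       fn=lambda part: [f(x) for x in part])

    def filter(self, pred):
        return MiniRDD(self.num_partitions, parent=self,
                       fn=lambda part: [x for x in part if pred(x)])

    def compute(self, index):
        """Rebuild one partition by walking the lineage back to the source."""
        if self.parent is None:
            return list(self.source(index))
        return self.fn(self.parent.compute(index))

root = MiniRDD(2, source=lambda i: range(i * 5, (i + 1) * 5))
derived = root.map(lambda x: x * 2).filter(lambda x: x > 4)

# Simulate losing partition 1: only that partition is rebuilt
recovered = derived.compute(1)
print(recovered)
```

Because each partition depends on exactly one parent partition here, rebuilding partition 1 never touches partition 0. A wide dependency would require reading from several parent partitions.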
Checkpointing: Breaking Lineage
For long lineage chains, checkpointing saves the RDD to reliable storage (HDFS/S3) and truncates its lineage. This bounds the recomputation needed after a failure and avoids the StackOverflowError that very deep lineage graphs can trigger.
# Enable checkpointing
sc.setCheckpointDir("s3a://bucket/checkpoints/")
# Build a deep pipeline
rdd = sc.parallelize(range(1, 10001))
for i in range(100):
    rdd = rdd.map(lambda x, i=i: x + i)
# Mark the RDD for checkpointing; this must happen before any action runs on it
rdd.checkpoint()
rdd.count()  # First action materializes the RDD and writes the checkpoint
# After checkpoint, lineage is truncated
print(rdd.toDebugString().decode())
ℹ️ Netflix Interview Insight
At Netflix, checkpointing is critical for iterative algorithms (like recommendation model training). Without it, the lineage grows with every iteration, so a failure forces Spark to replay the whole chain from the original source. Periodic checkpointing caps that recomputation cost.
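The effect on recovery cost can be approximated with simple arithmetic. The function below is a rough model, assuming one transformation step per iteration and that nothing useful survives in cache after the failure; `steps_to_replay` and its parameters are hypothetical names.

```python
def steps_to_replay(failed_at, checkpoint_every=None):
    """Steps to redo if a partition is lost after `failed_at` steps,
    given a checkpoint every `checkpoint_every` steps (None = never)."""
    if not checkpoint_every:
        return failed_at  # replay from the original source
    last_checkpoint = (failed_at // checkpoint_every) * checkpoint_every
    return failed_at - last_checkpoint  # replay from the latest checkpoint

without_cp = steps_to_replay(95)
with_cp = steps_to_replay(95, checkpoint_every=10)
print(without_cp, with_cp)
```

The trade-off is that each checkpoint writes the full RDD to storage, so the interval is a balance between write cost and worst-case replay length.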
Real-World Scenario: Meta Social Graph Processing
Problem Statement
Compute mutual friends between users in a social network using RDDs. For each pair of users (A, B), find users who are friends with both A and B.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("MutualFriends").getOrCreate()
sc = spark.sparkContext
# Load edges: (user_id, friend_id)
edges = sc.textFile("s3a://social-graph/edges/")
edge_pairs = edges.map(lambda line: line.split("\t")) \
    .map(lambda parts: (int(parts[0]), int(parts[1])))
# Assumes friendships are symmetric and listed in both directions: (A, B) and (B, A).
# If A is friends with both B and C, then A is a mutual friend of the pair (B, C).
def generate_paths(user_friends):
    user, friends = user_friends
    paths = []
    for i in range(len(friends)):
        for j in range(i + 1, len(friends)):
            # Emit both orderings so a lookup by (B, C) or (C, B) succeeds
            paths.append(((friends[i], friends[j]), user))
            paths.append(((friends[j], friends[i]), user))
    return paths
# Group friends by user (groupByKey is justified here: the full friend list is needed)
user_friends = edge_pairs.groupByKey().mapValues(list)
# Emit one (pair, mutual_friend) record for every pair of a user's friends
mutual_candidates = user_friends.flatMap(generate_paths)
# Group by pair: the collected values are that pair's mutual friends.
# The filter keeps only pairs with at least two mutual friends; drop it to keep every pair.
mutual_friends = mutual_candidates \
    .groupByKey() \
    .mapValues(list) \
    .filter(lambda x: len(x[1]) >= 2)
# Take sample results
for pair, friends in mutual_friends.take(10):
    print(f"Users {pair[0]} and {pair[1]} share friends: {friends}")
spark.stop()
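For checking the logic on a small input, the same idea can be run in plain Python without Spark. This is a sketch under the assumption that every friendship appears in both directions; `mutual_friends_local` is a hypothetical helper, and unlike the RDD version it emits each unordered pair once.

```python
from collections import defaultdict
from itertools import combinations

def mutual_friends_local(edges):
    """edges: (user, friend) pairs, with each friendship listed in both directions."""
    friends_of = defaultdict(set)
    for user, friend in edges:
        friends_of[user].add(friend)

    mutual = defaultdict(set)
    for user, friends in friends_of.items():
        # Every pair of this user's friends has `user` as a mutual friend
        for a, b in combinations(sorted(friends), 2):
            mutual[(a, b)].add(user)
    return dict(mutual)

# Tiny symmetric graph with friendships 1-2, 1-3, 2-3, 3-4
undirected = [(1, 2), (1, 3), (2, 3), (3, 4)]
sample_edges = undirected + [(b, a) for a, b in undirected]
result = mutual_friends_local(sample_edges)
print(result)
```

A fixture like this is useful for validating the distributed job's output on a sample before running it on the full graph.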
Performance: mapPartitions vs map
mapPartitions is significantly more efficient than map when the transformation involves expensive initialization (database connections, model loading).
# INEFFICIENT: opens a connection for every element
# (create_db_connection and my_table are placeholders for your own client and table)
def process_element(x):
    conn = create_db_connection()  # New connection per element!
    result = conn.query(f"SELECT * FROM my_table WHERE id = {x}")
    conn.close()
    return result
rdd.map(process_element).collect()
# EFFICIENT: opens one connection per partition
def process_partition(iterator):
    conn = create_db_connection()  # One connection per partition
    try:
        for x in iterator:
            # Yield lazily instead of buffering the whole partition in a list
            yield conn.query(f"SELECT * FROM my_table WHERE id = {x}")
    finally:
        conn.close()  # Runs once the partition is fully consumed, or on error
rdd.mapPartitions(process_partition).collect()
💡 Amazon Pro Tip
When processing millions of records, mapPartitions can reduce setup overhead by orders of magnitude. For a 10M-record dataset with 200 partitions, you open 200 connections instead of 10M. When connection setup dominates the per-record cost, this alone can cut runtime from hours to minutes.
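The connection-count arithmetic can be checked with a small plain-Python model. `FakeConnection` is a hypothetical stand-in for a database client that only counts how often it is opened; partitions are ordinary lists rather than Spark partitions.

```python
class FakeConnection:
    """Hypothetical database client that counts how many times it is opened."""
    opened = 0

    def __init__(self):
        FakeConnection.opened += 1

    def query(self, x):
        return x * 10

    def close(self):
        pass

def per_element(partitions):
    out = []
    for part in partitions:
        for x in part:
            conn = FakeConnection()  # map-style: setup cost paid per record
            out.append(conn.query(x))
            conn.close()
    return out

def per_partition(partitions):
    out = []
    for part in partitions:
        conn = FakeConnection()  # mapPartitions-style: setup cost paid per partition
        out.extend(conn.query(x) for x in part)
        conn.close()
    return out

partitions = [list(range(p * 250, (p + 1) * 250)) for p in range(4)]

FakeConnection.opened = 0
results_a = per_element(partitions)
opened_per_element = FakeConnection.opened

FakeConnection.opened = 0
results_b = per_partition(partitions)
opened_per_partition = FakeConnection.opened

print(opened_per_element, opened_per_partition)
```

Both strategies produce identical results; only the number of setup calls differs.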
Persistence Levels
| Level | Storage | CPU Cost | Use Case |
|---|---|---|---|
| NONE | No caching | None | One-time use |
| DISK_ONLY | Disk only | Low | Large datasets that don't fit in memory |
| DISK_ONLY_2 | Disk (2 replicas) | Low | Fault-tolerant disk caching |
| MEMORY_ONLY | Memory | Medium | Fits in memory, fast access; default for RDD.cache() |
| MEMORY_ONLY_2 | Memory (2 replicas) | Medium | Fault-tolerant memory caching |
| MEMORY_ONLY_SER | Memory (serialized) | High | Scala/Java only; PySpark always stores pickled data, so it has no separate _SER levels |
| MEMORY_AND_DISK | Memory + spill to disk | Medium | Spills to disk if memory is full; default for DataFrame caching |
| MEMORY_AND_DISK_2 | Memory + disk (2 replicas) | Medium | Fault-tolerant memory+disk |
| OFF_HEAP | Off-heap memory | Low | Avoid GC overhead |
from pyspark import StorageLevel
# Cache in memory
rdd.cache()  # Same as rdd.persist(StorageLevel.MEMORY_ONLY)
# A storage level cannot be changed once assigned, so unpersist before switching
rdd.unpersist()
# Persist to disk
rdd.persist(StorageLevel.DISK_ONLY)
rdd.unpersist()
# Note: StorageLevel.MEMORY_ONLY_SER is not defined in current PySpark releases.
# Python data is always stored pickled, so MEMORY_ONLY already holds serialized bytes
# Fault-tolerant with 2 replicas
rdd.persist(StorageLevel.MEMORY_AND_DISK_2)
# Unpersist when done
rdd.unpersist()
Custom Partitioner
# PySpark exposes no public Partitioner class to subclass (unlike Scala).
# A custom partitioner is a plain function from key to int, passed to partitionBy().
NUM_PARTITIONS = 4
def user_partitioner(key):
    # Route "user:<id>" keys by numeric ID; deterministic across worker processes
    return int(key.split(":", 1)[1]) % NUM_PARTITIONS
# Apply the custom partitioner
keyed_rdd = sc.parallelize([
    ("user:1001", "click"),
    ("user:1002", "view"),
    ("user:1003", "purchase"),
    ("user:1001", "scroll")
])
partitioned = keyed_rdd.partitionBy(
    numPartitions=NUM_PARTITIONS,
    partitionFunc=user_partitioner
)
# Verify partitions
print(f"Number of partitions: {partitioned.getNumPartitions()}")
# Use mapPartitions to process each partition independently
def analyze_partition(iterator):
    records = list(iterator)
    yield {
        "count": len(records),
        "keys": [r[0] for r in records]
    }
partition_stats = partitioned.mapPartitions(analyze_partition)
for stat in partition_stats.collect():
    print(stat)
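Before applying a key-to-partition function to real data, it can help to estimate how evenly it spreads keys. The following plain-Python sketch does that for keys of the form "user:<id>"; `assign_partitions` and `by_user_id` are hypothetical helpers, and the skew ratio (largest partition divided by the ideal even share) is one simple metric among several possible ones.

```python
from collections import Counter

def assign_partitions(keys, partition_func, num_partitions):
    """Count how many records each partition would receive."""
    return Counter(partition_func(k) % num_partitions for k in keys)

def by_user_id(key):
    # Assumes keys look like "user:<numeric id>"
    return int(key.split(":", 1)[1])

keys = ["user:1001", "user:1002", "user:1003", "user:1001"]
sizes = assign_partitions(keys, by_user_id, num_partitions=4)

# Skew ratio: 1.0 means perfectly even; higher means one partition is overloaded
skew = max(sizes.values()) / (len(keys) / 4)
print(dict(sizes), skew)
```

Both records for user 1001 land in the same partition, which is the co-location property a partitioner is meant to guarantee; the cost is that a very active key concentrates load on one partition.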
Edge Cases and Gotchas
1. Driver OOM from collect()
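# DANGEROUS: brings all data to the driver
all_data = rdd.collect()  # Can crash the driver with large RDDs
# SAFER: stream one partition at a time to the driver
# (glom().collect() would still pull every partition into driver memory at once)
for record in rdd.toLocalIterator():
    process(record)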
2. Serialization Overhead with PySpark RDDs
# INEFFICIENT: every Python object is pickled and shipped between the JVM and Python workers
rdd = sc.parallelize([MyClass() for _ in range(1000)])
# BETTER: use DataFrames with an explicit schema for structured records
df = spark.createDataFrame(data, schema)
3. Shuffle Partition Explosion
# BAD: Creates too many partitions
rdd = sc.parallelize(range(1000000))
rdd = rdd.repartition(1000000) # 1M partitions!
# GOOD: Balanced partitioning
rdd = rdd.repartition(1000) # Reasonable
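One common way to pick a "reasonable" partition count is to size partitions by data volume. The helper below is a rule-of-thumb sketch, not a Spark API: `suggest_partitions` is a hypothetical name, and the 128 MiB target is an assumption that should be tuned per workload and cluster.

```python
import math

def suggest_partitions(total_bytes, target_partition_bytes=128 * 1024**2, min_partitions=1):
    """Enough partitions that each one holds roughly the target size."""
    return max(min_partitions, math.ceil(total_bytes / target_partition_bytes))

n = suggest_partitions(100 * 1024**3)  # 100 GiB of input
print(n)
```

In practice the result is often rounded up to a multiple of the total executor cores so that every core has work in each wave of tasks.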
⚠️ Common Pitfall
RDDs in PySpark pickle Python objects and move them between the JVM and Python worker processes, which makes them significantly slower than DataFrames, whose data stays in Tungsten's binary format inside the JVM. Use DataFrames for structured data; reserve RDDs for unstructured data or complex custom logic.
When to Use RDDs Over DataFrames
- Unstructured data: Raw text processing without clear schema
- Complex custom logic: Transformations that don't map to SQL/DataFrame API
- Low-level control: Need explicit partitioning or data placement
- Iterative algorithms: Graph processing, ML algorithms with custom iteration
- Legacy codebases: Existing RDD-based code
Summary
RDDs are Spark's foundational abstraction. While DataFrames are preferred for most structured data workloads, understanding RDDs deeply is crucial for interviews at Meta and Netflix where complex distributed computing problems require low-level control. Master the distinction between narrow and wide transformations, understand lineage-based fault tolerance, and know when to use mapPartitions over map.