UDF Performance: Pandas UDF, Arrow, COALESCE vs UDF
Difficulty: Expert | Companies: Meta, Netflix, Uber, Airbnb, LinkedIn
ℹ️ Interview Context
UDF performance is critical because poorly written UDFs can be 100x slower than built-in functions. Interviewers test understanding of serialization overhead, vectorization, and when to use each UDF type.
Question
Compare Python UDFs, Pandas UDFs, and built-in Spark functions in terms of performance. How does Apache Arrow improve UDF performance? When should you use COALESCE instead of a UDF? Provide benchmarks and optimization strategies.
Detailed Answer
1. UDF Performance Comparison
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
import pandas as pd
spark = SparkSession.builder \
.appName("UDFPerformance") \
.config("spark.sql.execution.arrow.pyspark.enabled", "true") \
.config("spark.sql.execution.arrow.maxRecordsPerBatch", "10000") \
.getOrCreate()
# Test data
df = spark.range(10000000).withColumn(
"value", F.randn()
).withColumn(
"category", F.array(
F.lit("A"), F.lit("B"), F.lit("C")
)[F.floor(F.rand() * 3).cast("int")]
)
# Performance comparison:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 110" width="100%" style={{ maxWidth: 720 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="udf-hdr" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<filter id="udf-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
</filter>
</defs>
<rect x="10" y="10" width="780" height="90" rx="14" fill="#fff" filter="url(#udf-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
<rect x="10" y="10" width="780" height="30" rx="14" fill="url(#udf-hdr)"/>
<rect x="10" y="24" width="780" height="16" fill="url(#udf-hdr)"/>
<text x="140" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">UDF Type</text>
<text x="340" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Latency</text>
<text x="500" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Throughput</text>
<text x="680" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Serialization</text>
<text x="140" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Built-in Function</text>
<text x="340" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">1x (baseline)</text>
<text x="500" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Highest</text>
<text x="680" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">None (JVM)</text>
<line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="140" y="80" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Pandas UDF</text>
<text x="340" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">5-10x</text>
<text x="500" y="80" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">High</text>
<text x="680" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Arrow (columnar batches)</text>
<line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="140" y="100" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Python UDF</text>
<text x="340" y="100" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">100-1000x</text>
<text x="500" y="100" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Low</text>
<text x="680" y="100" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Pickle (slow)</text>
</svg>
</div>
2. Python UDFs (Row-at-a-Time)
# Python UDF: the function is called once per row
# Serialization: JVM rows → pickle (in small groups) → Python worker → pickle → JVM rows
# Overhead: ~100-1000x slower than built-in functions
@F.udf(returnType=DoubleType())
def python_udf(value):
    """Process a single row; slow due to per-row serialization."""
    return value * 2 + 1
# Apply Python UDF
result_py = df.withColumn("doubled", python_udf(F.col("value")))
# Performance analysis:
# For N rows:
# Serialization: O(N × row_size) (pickle serialization)
# Processing: O(N × computation) (Python execution)
# Deserialization: O(N × row_size) (pickle deserialization)
# Total: O(N × (2 × serialization + computation))
# Serialization cost:
# Each row crosses the process boundary twice: JVM row → pickle bytes → Python object, then back
# Pickle overhead: ~50-100 bytes per row (headers, type info)
# For 10M rows: 500MB - 1GB serialization overhead
# Variant: pass an array column so one call handles several values
@F.udf(returnType=ArrayType(DoubleType()))
def python_udf_batch(values):
    """Process an array of values per row; still limited by serialization."""
    return [v * 2 + 1 for v in values]
# Still slow: every element is still pickled and unpickled as a Python object
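To get a feel for why per-row pickling is costly, the stdlib sketch below compares pickling every float as its own message with packing the same floats into one contiguous float64 buffer, which is roughly how Arrow lays out a numeric column. It is a local illustration only, not Spark's wire protocol (Spark pickles rows in small groups and adds its own framing), and both helper names are hypothetical.

```python
import pickle
from array import array


def per_row_pickle_bytes(values):
    """Total bytes if every value is pickled as a separate message (row-at-a-time)."""
    return sum(len(pickle.dumps(v)) for v in values)


def columnar_buffer_bytes(values):
    """Bytes if the same values are packed into one contiguous float64 buffer."""
    return len(array("d", values).tobytes())


values = [float(i) for i in range(100_000)]
row_bytes = per_row_pickle_bytes(values)
col_bytes = columnar_buffer_bytes(values)
print(f"per-row pickle:  {row_bytes:,} bytes")
print(f"columnar buffer: {col_bytes:,} bytes")
print(f"ratio: {row_bytes / col_bytes:.1f}x")
```

The byte ratio is only part of the story: each pickled value also has to be turned into a Python object and back, which is where most of the CPU time goes.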
3. Pandas UDFs (Vectorized)
# Pandas UDF: process batches of data using pandas/NumPy
# Serialization: Arrow columnar batches (no per-row pickling; buffers are still copied between the JVM and the Python worker)
# Performance: often 10-100x faster than an equivalent Python UDF
# Scalar Pandas UDF (one-to-one)
@F.pandas_udf(DoubleType())
def pandas_udf_scalar(value: pd.Series) -> pd.Series:
    """Process a column vector; Arrow serialization."""
    return value * 2 + 1
# Apply Scalar Pandas UDF
result_pd = df.withColumn("doubled", pandas_udf_scalar(F.col("value")))
# Grouped Map (many-to-many): each group arrives as a pandas DataFrame; any number of rows may be returned
# Spark 3.0+ API: groupBy().applyInPandas(); the PandasUDFType.GROUPED_MAP style is deprecated
grouped_schema = StructType([
    StructField("category", StringType()),
    StructField("avg_value", DoubleType()),
    StructField("count", LongType())
])
def pandas_udf_grouped(pdf: pd.DataFrame) -> pd.DataFrame:
    """Process an entire group (held in memory); Arrow serialization."""
    return pd.DataFrame({
        "category": [pdf["category"].iloc[0]],
        "avg_value": [pdf["value"].mean()],
        "count": [len(pdf)]
    })
# Apply Grouped Map
result_grouped = df.groupBy("category").applyInPandas(pandas_udf_grouped, schema=grouped_schema)
# Grouped Aggregate Pandas UDF (many-to-one); the Series -> scalar type hints select this UDF type
@F.pandas_udf(DoubleType())
def pandas_udf_agg(value: pd.Series) -> float:
    """Aggregate a group; Arrow serialization."""
    return value.mean()
# Apply Aggregate
result_agg = df.groupBy("category").agg(
    pandas_udf_agg(F.col("value")).alias("avg_value")
)
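# Arrow serialization benefits:
# 1. Columnar binary format understood by both the JVM and Python: no per-row object conversion
#    (buffers are still copied over a local socket, so the JVM-to-Python hop is not zero-copy)
# 2. Batch processing: one Python call per record batch instead of one per row
# 3. No pickle overhead
# 4. Type-safe (schema enforced)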
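The batching idea can be sketched in plain Python: the column is cut into chunks of at most `maxRecordsPerBatch` rows, and the user function runs once per chunk instead of once per value. This is a simplified model under the assumption that a batch is just a list; real batches are Arrow record batches, and the helper names here are hypothetical.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(column: Sequence[float], max_records_per_batch: int) -> Iterator[List[float]]:
    """Yield consecutive slices of at most max_records_per_batch values."""
    if max_records_per_batch <= 0:
        raise ValueError("max_records_per_batch must be positive")
    for start in range(0, len(column), max_records_per_batch):
        yield list(column[start:start + max_records_per_batch])


def apply_batched(column, batch_fn: Callable[[List[float]], List[float]], max_records_per_batch=10_000):
    """Call batch_fn once per batch (Pandas UDF style) and count the calls."""
    out: List[float] = []
    calls = 0
    for batch in iter_batches(column, max_records_per_batch):
        out.extend(batch_fn(batch))
        calls += 1
    return out, calls


def apply_row_at_a_time(column, row_fn):
    """Call row_fn once per value (Python UDF style) and count the calls."""
    return [row_fn(v) for v in column], len(column)


column = [float(i) for i in range(25)]
batched, batch_calls = apply_batched(column, lambda b: [v * 2 + 1 for v in b], max_records_per_batch=10)
rowwise, row_calls = apply_row_at_a_time(column, lambda v: v * 2 + 1)
print(batch_calls, row_calls)  # 3 25
```

Both paths produce identical results; only the number of times control passes into the user function changes.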
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 80" width="100%" style={{ maxWidth: 700 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="udf-arr" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<linearGradient id="udf-val" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#3b82f6"/>
<stop offset="100%" stopColor="#2563eb"/>
</linearGradient>
<linearGradient id="udf-null" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#f59e0b"/>
<stop offset="100%" stopColor="#d97706"/>
</linearGradient>
<linearGradient id="udf-off" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#10b981"/>
<stop offset="100%" stopColor="#059669"/>
</linearGradient>
<filter id="udf-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.15"/>
</filter>
</defs>
<rect x="10" y="10" width="780" height="60" rx="14" fill="#f8fafc" filter="url(#udf-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
<text x="400" y="30" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Arrow Array</text>
<rect x="30" y="38" width="240" height="25" rx="8" fill="url(#udf-val)" filter="url(#udf-shadow)"/>
<text x="150" y="55" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Value Buffer (contiguous)</text>
<rect x="280" y="38" width="230" height="25" rx="8" fill="url(#udf-null)" filter="url(#udf-shadow)"/>
<text x="395" y="55" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Null Bitmap (bit-packed)</text>
<rect x="520" y="38" width="260" height="25" rx="8" fill="url(#udf-off)" filter="url(#udf-shadow)"/>
<text x="650" y="55" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Offset Array (if variable)</text>
</svg>
</div>
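The three buffers in the diagram can be built by hand for a nullable string column. The sketch below follows the Arrow columnar format's conventions (validity bits numbered least-significant-bit first, `n + 1` offsets, null slots taking zero bytes of data) but skips the padding and alignment a real Arrow implementation adds; the function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def encode_utf8_column(values: Sequence[Optional[str]]) -> Tuple[bytes, List[int], bytes]:
    """Return (validity_bitmap, offsets, data) for a nullable UTF-8 column."""
    validity = bytearray((len(values) + 7) // 8)  # one bit per slot, rounded up to whole bytes
    offsets = [0]                                  # n + 1 offsets into the data buffer
    data = bytearray()                             # all string bytes, back to back
    for i, value in enumerate(values):
        if value is not None:
            validity[i // 8] |= 1 << (i % 8)       # set bit i: slot i holds a value
            data += value.encode("utf-8")
        offsets.append(len(data))                  # a null slot repeats the previous offset
    return bytes(validity), offsets, bytes(data)


def decode_utf8_column(validity: bytes, offsets: List[int], data: bytes) -> List[Optional[str]]:
    """Invert encode_utf8_column to check that the layout round-trips."""
    out: List[Optional[str]] = []
    for i in range(len(offsets) - 1):
        is_valid = (validity[i // 8] >> (i % 8)) & 1
        out.append(data[offsets[i]:offsets[i + 1]].decode("utf-8") if is_valid else None)
    return out


validity, offsets, data = encode_utf8_column(["ab", None, "cde"])
print(validity, offsets, data)  # b'\x05' [0, 2, 2, 5] b'abcde'
```

Fixed-width types such as `DOUBLE` need only the validity bitmap and the value buffer; the offset buffer exists for variable-width types like strings and binary.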
4. Arrow Configuration and Tuning
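# Arrow configuration:
# arrow.pyspark.enabled controls toPandas() and createDataFrame(pandas_df); Pandas UDFs always use Arrow
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# maxRecordsPerBatch caps the rows in each Arrow batch handed to a Pandas UDF
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")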
# Batch size tuning:
# Larger batches: better throughput, more memory
# Smaller batches: lower latency, less memory
# Default: 10000 rows per batch
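# Memory calculation (rough):
# Batch memory ≈ batch_size × row_size × 2 (input + output)
# For 10000 rows × 100 bytes × 2 ≈ 2MB per batch
# Memory scales with concurrently running tasks, not total partitions:
# e.g. 200 concurrent tasks × 2MB ≈ 400MB of Arrow buffers across the cluster
# Arrow byte limit (only in newer Spark versions; check that this key exists in your version's docs):
spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "64m")
# Arrow fallback to non-Arrow:
# Applies to toPandas()/createDataFrame(pandas_df): if Arrow fails (e.g. unsupported types),
# Spark falls back to the slower pickle-based path. Pandas UDFs never fall back; they require Arrow.
# Disable for strict Arrow usage:
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")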
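The memory rule of thumb in the comments above (batch size × row size × 2) is easy to turn into a capacity-planning helper. This is a back-of-the-envelope model under the same assumptions (input plus output buffers, a fixed average row size, one batch in flight per running task); real usage also includes pandas copies and Arrow padding, and the function names are hypothetical.

```python
def arrow_batch_bytes(max_records_per_batch: int, avg_row_bytes: int, copies: int = 2) -> int:
    """Approximate bytes for one Arrow batch: rows x row size x (input + output)."""
    return max_records_per_batch * avg_row_bytes * copies


def cluster_arrow_bytes(concurrent_tasks: int, max_records_per_batch: int, avg_row_bytes: int) -> int:
    """Approximate Arrow buffer bytes when every running task holds one batch."""
    return concurrent_tasks * arrow_batch_bytes(max_records_per_batch, avg_row_bytes)


per_batch = arrow_batch_bytes(10_000, 100)
total = cluster_arrow_bytes(200, 10_000, 100)
print(f"{per_batch / 1e6:.0f} MB per batch, {total / 1e6:.0f} MB across 200 tasks")
# 2 MB per batch, 400 MB across 200 tasks
```

Halving `maxRecordsPerBatch` halves the estimate, at the cost of twice as many Python calls per partition.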
# Arrow type mapping:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 240" width="100%" style={{ maxWidth: 650 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="arr-hdr" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<filter id="arr-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
</filter>
</defs>
<rect x="10" y="10" width="780" height="220" rx="14" fill="#fff" filter="url(#arr-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
<rect x="10" y="10" width="780" height="30" rx="14" fill="url(#arr-hdr)"/>
<rect x="10" y="24" width="780" height="16" fill="url(#arr-hdr)"/>
<text x="160" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Spark Type</text>
<text x="400" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Arrow Type</text>
<text x="620" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Python Type</text>
<text x="160" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">IntegerType</text>
<text x="400" y="56" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">INT32</text>
<text x="620" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">int</text>
<line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">LongType</text>
<text x="400" y="80" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">INT64</text>
<text x="620" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">int</text>
<line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">FloatType</text>
<text x="400" y="104" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">FLOAT</text>
<text x="620" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">float</text>
<line x1="30" y1="114" x2="770" y2="114" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">DoubleType</text>
<text x="400" y="128" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">DOUBLE</text>
<text x="620" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">float</text>
<line x1="30" y1="138" x2="770" y2="138" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="152" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">StringType</text>
<text x="400" y="152" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">UTF8</text>
<text x="620" y="152" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">str</text>
<line x1="30" y1="162" x2="770" y2="162" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="176" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">BinaryType</text>
<text x="400" y="176" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">BINARY</text>
<text x="620" y="176" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">bytes</text>
<line x1="30" y1="186" x2="770" y2="186" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="200" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">BooleanType</text>
<text x="400" y="200" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">BOOL</text>
<text x="620" y="200" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">bool</text>
<line x1="30" y1="210" x2="770" y2="210" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="160" y="224" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">DateType</text>
<text x="400" y="224" textAnchor="middle" fill="#3b82f6" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">DATE32</text>
<text x="620" y="224" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">datetime.date</text>
</svg>
</div>
5. Built-in Functions vs UDFs
# Built-in functions: Always prefer over UDFs
# Reason: Catalyst can optimize them (predicate pushdown, codegen)
# Example: COALESCE vs UDF
# COALESCE: Built-in, optimized, no serialization
result = df.withColumn("filled", F.coalesce(F.col("value"), F.lit(0.0)))
# UDF: Not optimized, serialization overhead
@F.udf(returnType=DoubleType())
def fill_udf(value):
    return value if value is not None else 0.0
result_udf = df.withColumn("filled", fill_udf(F.col("value")))
# Performance comparison (illustrative orders of magnitude; measure on your own workload):
# COALESCE: ~100M rows/sec (JVM, codegen, no serialization)
# UDF: ~1M rows/sec (serialization, Python overhead)
# Difference: roughly 100x slower
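`COALESCE` returns its first non-null argument, which is exactly what `fill_udf` reimplements row by row. A plain-Python model of those semantics (a hypothetical helper, for illustration only) makes the equivalence easy to check:

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def coalesce(*args: Optional[T]) -> Optional[T]:
    """Return the first argument that is not None, or None if all are None (SQL COALESCE semantics)."""
    for arg in args:
        if arg is not None:
            return arg
    return None


def fill_row(value: Optional[float]) -> float:
    """Same logic as fill_udf above."""
    return value if value is not None else 0.0


# coalesce(value, 0.0) and fill_row(value) agree on every input
for v in [1.5, None, -2.0]:
    assert coalesce(v, 0.0) == fill_row(v)
```

Because the built-in expresses the same logic, Spark can evaluate it inside generated JVM code with no Python round trip at all.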
# When to use built-in functions:
# 1. String operations: F.upper(), F.lower(), F.concat()
# 2. Math operations: F.abs(), F.round(), F.pow()
# 3. Date operations: F.date_format(), F.datediff()
# 4. Conditional logic: F.when().otherwise()
# 5. Aggregations: F.sum(), F.avg(), F.count()
# 6. Window functions: F.row_number(), F.lead(), F.lag()
# When UDFs are necessary:
# 1. Complex business logic not expressible in SQL
# 2. External API calls (ML inference, geocoding)
# 3. Custom algorithms (not in Spark MLlib)
# 4. Legacy Python code integration
6. UDF Optimization Strategies
# Strategy 1: Use Pandas UDF instead of Python UDF
# 10-100x improvement with minimal code changes
# Strategy 2: Batch operations in Pandas UDF
@F.pandas_udf(DoubleType())
def optimized_pandas_udf(values: pd.Series) -> pd.Series:
    """Vectorized operation; NumPy/pandas optimized."""
    # NumPy operations are 10-100x faster than Python loops
    return values * 2 + 1  # Vectorized operation
# Strategy 3: Minimize data transfer
@F.pandas_udf(DoubleType())
def minimal_transfer_udf(value: pd.Series) -> pd.Series:
    """Process only the needed columns."""
    # Don't pass the entire row if only one column is needed
    return value.clip(lower=0)  # vectorized pandas Series.clip
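# Strategy 4: Use Arrow for complex types (arrays, structs)
@F.pandas_udf(ArrayType(DoubleType()))
def arrow_array_udf(values: pd.Series) -> pd.Series:
    """Arrow transfers arrays efficiently; each element arrives as a NumPy array."""
    # apply() still loops over rows in Python, but the math inside each row is vectorized
    return values.apply(lambda arr: arr * 2)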
# Strategy 5: Cache intermediate results
# If a UDF result feeds several downstream queries, compute it once and cache the result
my_udf = pandas_udf_scalar  # stand-in for any UDF defined above
df_scored = df.withColumn("computed", my_udf(F.col("value"))).cache()
result1 = df_scored.groupBy("category").agg(F.avg("computed"))
result2 = df_scored.filter(F.col("computed") > 3)
# Strategy 6: Avoid UDFs in joins/filters
# Catalyst cannot push UDF predicates down to the data source or reason about them,
# so apply cheap built-in filters first and run the UDF on fewer rows
filtered = (df.filter(F.col("value") > 0)
            .withColumn("computed", my_udf(F.col("value")))
            .filter(F.col("computed") > 3))
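Catalyst treats a UDF as an opaque function, so any cheap built-in filter that can run before it directly reduces how often the UDF executes. The toy model below counts invocations in plain Python; it is not Spark's planner, and the names are hypothetical.

```python
def make_counted(fn):
    """Wrap fn so every call is counted (stands in for an expensive UDF)."""
    def wrapper(x):
        wrapper.calls += 1
        return fn(x)
    wrapper.calls = 0
    return wrapper


rows = list(range(-500, 500))  # 1,000 rows; non-positive values can never pass the final filter

# UDF first, filter afterwards: the UDF runs on every row
udf_first = make_counted(lambda v: v * 2 + 1)
late = [c for c in map(udf_first, rows) if c > 3]

# Cheap built-in-style filter first: the UDF only runs on surviving rows
filter_first = make_counted(lambda v: v * 2 + 1)
early = [c for c in map(filter_first, (v for v in rows if v > 0)) if c > 3]

print(udf_first.calls, filter_first.calls)  # 1000 499
```

The final results are identical; the early filter simply spares the UDF roughly half of its work.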
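# Benchmark comparison:
import time
def benchmark(df, udf_func, iterations=10):
    """Average wall-clock seconds to evaluate udf_func over every row.

    udf_func maps a Column to a Column: python_udf, pandas_udf_scalar,
    or lambda c: c * 2 + 1 for the built-in baseline.
    """
    times = []
    for _ in range(iterations):
        start = time.perf_counter()
        # The "noop" sink (Spark 3.0+) forces every row to be computed; count()
        # alone can let the optimizer prune the unused "result" column
        (df.withColumn("result", udf_func(F.col("value")))
            .write.format("noop").mode("overwrite").save())
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)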
# Results (10M rows; illustrative, will vary by cluster, data, and Spark version):
# Built-in: 0.8 seconds
# Pandas UDF: 4.2 seconds (5x slower)
# Python UDF: 85.3 seconds (106x slower)
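A small variation on `benchmark()` above discards warm-up runs and reports the median rather than the mean, so first-run costs (JIT compilation, code generation, caching) and one-off stragglers skew the result less. The helper name is hypothetical and it is plain stdlib, so it can time any zero-argument callable.

```python
import statistics
import time
from typing import Callable


def time_runs(action: Callable[[], object], iterations: int = 10, warmup: int = 1) -> float:
    """Median wall-clock seconds of action(), after discarding warm-up runs."""
    for _ in range(warmup):
        action()  # first-run effects (JIT, codegen, caches) land here and are not timed
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        action()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

With Spark, the callable would wrap a full evaluation, for example `time_runs(lambda: df.withColumn("result", python_udf(F.col("value"))).write.format("noop").mode("overwrite").save())` on Spark 3.0+.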
7. UDF Testing and Debugging
# Test UDF independently:
import pandas as pd
import numpy as np
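def test_pandas_udf():
    """Test the Pandas UDF logic with sample data (no Spark job needed)."""
    test_input = pd.Series([1.0, 2.0, 3.0, None, 5.0])
    # Nulls arrive as NaN, and NaN * 2 + 1 stays NaN
    expected_output = pd.Series([3.0, 5.0, 7.0, np.nan, 11.0])
    # .func is the plain Python function wrapped by @pandas_udf;
    # calling pandas_udf_scalar(...) directly would build a Column expression instead
    result = pandas_udf_scalar.func(test_input)
    pd.testing.assert_series_equal(result, expected_output)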
# Debug UDF in Spark:
# 1. Use local mode for debugging (in a fresh process: getOrCreate() reuses any running session and its master)
spark = SparkSession.builder \
.master("local[*]") \
.appName("DebugUDF") \
.getOrCreate()
# 2. Test with small dataset
test_df = spark.createDataFrame(
[(1.0,), (2.0,), (None,), (4.0,)],
["value"]
)
# 3. Show results
test_df.withColumn("result", pandas_udf_scalar(F.col("value"))).show()
# 4. Check UDF execution plan
test_df.withColumn("result", pandas_udf_scalar(F.col("value"))).explain()
# Common UDF issues:
# 1. Type mismatch: UDF return type doesn't match actual return
# 2. Null handling: UDF doesn't handle None values
# 3. Performance: Using Python loops instead of vectorized operations
# 4. Serialization: UDF uses unsupported types for Arrow
# Null handling example:
@F.pandas_udf(DoubleType())
def null_safe_udf(value: pd.Series) -> pd.Series:
    """Handle nulls explicitly."""
    return value.fillna(0.0) * 2 + 1
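A row-at-a-time Python UDF needs the same null rule written out by hand, because `None * 2` raises `TypeError`. Keeping the logic in a plain function (the name below is hypothetical) lets it be unit-tested without Spark before it is wrapped. One subtlety: inside a Pandas UDF, Spark nulls arrive as `NaN`, so `fillna(0.0)` also replaces genuine NaN values, while this row-level version only treats `None` as missing.

```python
from typing import Optional


def null_safe_row(value: Optional[float]) -> float:
    """Row-level counterpart of null_safe_udf: None becomes 0.0, then 2x + 1."""
    if value is None:
        value = 0.0
    return value * 2 + 1


assert null_safe_row(None) == 1.0
assert null_safe_row(2.0) == 5.0
```

Once tested, it can be registered with `F.udf(null_safe_row, DoubleType())`.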
⚠️ Common Pitfall
Using Python UDFs for simple operations that Spark already supports (like string manipulation, math, or date functions) is a major performance anti-pattern. Always check if a built-in function exists before writing a UDF.
💡 Interview Tip
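When discussing UDF performance, always mention the serialization overhead. Python UDFs convert every row to and from individual pickled Python objects and call the function once per row (N calls), while Pandas UDFs ship Arrow record batches and call the function once per batch (about N / maxRecordsPerBatch calls, each vectorized). The bytes moved are O(N) in both cases; the win comes from eliminating per-row object conversion and interpreter overhead, which is why Pandas UDFs are often 10-100x faster.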
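Plugging in this section's numbers (10M rows, the default `maxRecordsPerBatch` of 10,000) makes the call-count difference concrete. The sketch assumes, as in Spark's per-task execution, that an Arrow batch never spans two partitions, so the Pandas UDF count is summed per partition; the helper names are hypothetical.

```python
import math
from typing import Iterable


def python_udf_calls(rows_per_partition: Iterable[int]) -> int:
    """Row-at-a-time Python UDF: one Python call per row."""
    return sum(rows_per_partition)


def pandas_udf_calls(rows_per_partition: Iterable[int], max_records_per_batch: int = 10_000) -> int:
    """Pandas UDF: one call per Arrow batch, counted partition by partition."""
    return sum(math.ceil(n / max_records_per_batch) for n in rows_per_partition)


partitions = [50_000] * 200  # 10M rows spread evenly over 200 partitions
print(python_udf_calls(partitions), pandas_udf_calls(partitions))  # 10000000 1000
```

Uneven partitions raise the Pandas UDF count slightly (each partition's last batch is partly empty), but it stays orders of magnitude below the row count.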
Summary
| UDF Type | Performance | Serialization | Optimization | Best For |
|---|---|---|---|---|
| Built-in Function | 1x (fastest) | None (JVM) | Full Catalyst | All standard operations |
| Pandas UDF | 5-10x slower | Arrow (columnar batches) | Vectorized | Complex batch operations |
| Python UDF | 100-1000x slower | Pickle (slow) | None | External APIs, legacy code |
The key to UDF optimization is: always prefer built-in functions, use Pandas UDFs when custom logic is needed, and avoid Python UDFs unless absolutely necessary.