
Testing Spark: LocalCluster, SharedSparkSession, Assertions


Difficulty: Expert | Companies: Databricks, Netflix, Uber, Airbnb, LinkedIn

ℹ️ Interview Context

Testing Spark applications is often overlooked but critical for production reliability. Interviewers expect knowledge of testing strategies, frameworks, and how to test Spark code efficiently.

Question

How do you test Spark applications effectively? Compare different testing approaches: local mode, LocalCluster, and SharedSparkSession. What are the best practices for testing DataFrame transformations, Spark SQL queries, and streaming queries? How do you handle flaky tests and slow test suites?


Detailed Answer

1. Testing Approaches Overview

# Testing approaches comparison:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 140" width="100%" style={{ maxWidth: 700 }} xmlns="http://www.w3.org/2000/svg">
  <defs>
    <linearGradient id="test-hdr" x1="0" y1="0" x2="1" y2="1">
      <stop offset="0%" stopColor="#6366f1"/>
      <stop offset="100%" stopColor="#4f46e5"/>
    </linearGradient>
    <filter id="test-shadow">
      <feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
    </filter>
  </defs>
  <rect x="10" y="10" width="780" height="120" rx="14" fill="#fff" filter="url(#test-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
  <rect x="10" y="10" width="780" height="30" rx="14" fill="url(#test-hdr)"/>
  <rect x="10" y="24" width="780" height="16" fill="url(#test-hdr)"/>
  <text x="140" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Approach</text>
  <text x="340" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Speed</text>
  <text x="500" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Realism</text>
  <text x="680" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Complexity</text>
  <text x="140" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Local Mode</text>
  <text x="340" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Fast</text>
  <text x="500" y="56" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Low</text>
  <text x="680" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Simple</text>
  <line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="140" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">LocalCluster</text>
  <text x="340" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
  <text x="500" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
  <text x="680" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
  <line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="140" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">SharedSparkSession</text>
  <text x="340" y="104" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
  <text x="500" y="104" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">High</text>
  <text x="680" y="104" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
  <line x1="30" y1="114" x2="770" y2="114" stroke="#e2e8f0" strokeWidth="0.5"/>
  <text x="140" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Integration Test</text>
  <text x="340" y="128" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Slow</text>
  <text x="500" y="128" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">High</text>
  <text x="680" y="128" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Complex</text>
</svg>
</div>

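# Recommendation: use a SharedSparkSession (one local-mode session reused by
# every unit test) so JVM startup is paid once per run, not once per test.
# Use LocalCluster and integration tests for end-to-end validation.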

2. SharedSparkSession (Recommended)

import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType

@pytest.fixture(scope="session")
def spark():
    """Create shared SparkSession for all tests."""
    spark = SparkSession.builder \
        .master("local[*]") \
        .appName("test") \
        .config("spark.ui.enabled", "false") \
        .config("spark.sql.shuffle.partitions", "2") \
        .config("spark.driver.bindAddress", "127.0.0.1") \
        .getOrCreate()
    
    yield spark
    
    spark.stop()

@pytest.fixture(scope="function")
def sample_data(spark):
    """Create sample data for tests."""
    return spark.createDataFrame(
        [(1, "Alice", 100.0),
         (2, "Bob", 200.0),
         (3, "Charlie", 150.0)],
        ["id", "name", "amount"]
    )

def test_basic_transformation(spark, sample_data):
    """Test basic DataFrame transformation."""
    result = sample_data.withColumn(
        "doubled", F.col("amount") * 2
    )
    
    # Assert schema
    assert result.schema["doubled"].dataType == DoubleType()
    
    # Assert values
    # Sort by id: collect() order is only guaranteed after an explicit orderBy
    rows = result.orderBy("id").collect()
    assert len(rows) == 3
    assert rows[0]["doubled"] == 200.0
    assert rows[1]["doubled"] == 400.0
    assert rows[2]["doubled"] == 300.0

def test_filter(spark, sample_data):
    """Test filter operation."""
    result = sample_data.filter(F.col("amount") > 120)
    
    assert result.count() == 2
    assert set(result.select("name").collect()) == {("Bob",), ("Charlie",)}

def test_aggregation(spark, sample_data):
    """Test aggregation."""
    result = sample_data.agg(
        F.sum("amount").alias("total"),
        F.avg("amount").alias("avg")
    ).collect()[0]
    
    assert result["total"] == 450.0
    assert result["avg"] == 150.0

3. LocalCluster Mode

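# LocalCluster mode: master URL local-cluster[N, C, M] starts N workers with
# C cores and M MiB of memory per worker; executors run in separate JVMs.
# More realistic than local mode (separate processes), and documented by
# Spark as intended for unit tests only.
# Only one SparkContext can be active per JVM: if a local[*] session already
# exists, getOrCreate() returns it and the master below is silently ignored.
# Run LocalCluster tests in a separate pytest invocation.

@pytest.fixture(scope="session")
def spark_cluster():
    """Create SparkSession with LocalCluster."""
    # Parentheses instead of backslashes: a comment cannot follow a `\` continuation
    spark = (
        SparkSession.builder
        .master("local-cluster[2, 4, 1024]")  # 2 workers, 4 cores and 1024 MiB each
        .appName("test-cluster")
        .config("spark.ui.enabled", "false")
        .config("spark.executor.memory", "1g")  # must fit in a worker's 1024 MiB
        .getOrCreate()
    )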
    
    yield spark
    
    spark.stop()

def test_with_cluster(spark_cluster):
    """Test with LocalCluster (more realistic)."""
    df = spark_cluster.range(1000000)
    
    # This test runs with actual executor processes
    result = df.repartition(10).count()
    assert result == 1000000
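Because only one SparkContext can run per JVM, one practical layout is a single `spark` fixture whose master is chosen per pytest invocation. A sketch under that assumption: the environment variable `SPARK_TEST_MASTER` and both helper names are hypothetical, and the parser simply mirrors the documented `local-cluster[N,C,M]` shape (N workers, C cores per worker, M MiB per worker).

```python
import os
import re
from typing import Optional

# local-cluster[N, C, M]: N workers, C cores per worker, M MiB of memory per worker
_LOCAL_CLUSTER = re.compile(r"local-cluster\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]")


def parse_local_cluster(master: str) -> Optional[dict]:
    """Return the cluster shape for a local-cluster master URL, else None."""
    match = _LOCAL_CLUSTER.fullmatch(master.strip())
    if match is None:
        return None
    workers, cores, memory_mb = (int(g) for g in match.groups())
    return {
        "workers": workers,
        "cores_per_worker": cores,
        "memory_mb_per_worker": memory_mb,
        "total_cores": workers * cores,
    }


def spark_test_master(default: str = "local[2]") -> str:
    """Master URL for the test session, read from the hypothetical SPARK_TEST_MASTER."""
    return os.environ.get("SPARK_TEST_MASTER", default)
```

The conftest fixture would call `.master(spark_test_master())`, and LocalCluster runs become a separate command such as `SPARK_TEST_MASTER='local-cluster[2,2,1024]' pytest tests/integration/`, so the two masters never share a JVM.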

4. Test Assertions

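# Comprehensive assertion strategies:
from collections import Counter

import pytest
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType
from pyspark.sql.utils import AnalysisException

def assert_dataframe_equal(actual, expected, check_schema=True, 
                           check_values=True, check_row_order=False):
    """Assert two DataFrames are equal.

    Schema equality includes each field's nullable flag. Row order is ignored
    by default because it is not guaranteed after shuffles (joins, groupBy, windows).
    """
    
    # Check schema
    if check_schema:
        assert actual.schema == expected.schema, \
            f"Schema mismatch: {actual.schema} != {expected.schema}"
    
    # Collect once and reuse the rows instead of running separate count() jobs
    actual_rows = actual.collect()
    expected_rows = expected.collect()
    assert len(actual_rows) == len(expected_rows), \
        f"Row count mismatch: {len(actual_rows)} != {len(expected_rows)}"
    
    # Check values
    if check_values:
        if check_row_order:
            assert actual_rows == expected_rows, \
                f"Values mismatch:\nActual: {actual_rows}\nExpected: {expected_rows}"
        else:
            # Counter (a multiset) keeps duplicate rows that set() would collapse;
            # rows must be hashable (no array or map columns)
            assert Counter(actual_rows) == Counter(expected_rows), \
                f"Values mismatch (order ignored):\nActual: {actual_rows}\nExpected: {expected_rows}"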

def test_with_approx(spark):
    """Test with approximate floating point comparison."""
    df = spark.createDataFrame([(1.0,)], ["value"])
    result = df.withColumn("sqrt", F.sqrt(F.col("value"))).collect()[0]
    
    assert result["sqrt"] == pytest.approx(1.0, rel=1e-6)

def test_schema(spark):
    """Test DataFrame schema."""
    df = spark.createDataFrame([], "id: int, name: string, amount: double")
    
    expected_schema = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True),
        StructField("amount", DoubleType(), True)
    ])
    
    assert df.schema == expected_schema

def test_exceptions(spark):
    """Test that exceptions are raised correctly."""
    df = spark.createDataFrame([], "id: int")
    
    with pytest.raises(AnalysisException):
        df.select("nonexistent_column").collect()
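Exact `==` on collected rows fails on tiny floating-point differences, and `set()` drops duplicates. A sketch of an order-insensitive comparison that keeps duplicates and compares floats with a tolerance; it works directly on `collect()` output because `Row` is a tuple subclass. `assert_rows_equal` is a hypothetical helper. On Spark 3.5 or later, `pyspark.testing.assertDataFrameEqual` covers similar ground and is worth checking first.

```python
import math
from typing import Any, Iterable, Sequence


def _sort_key(row: Sequence[Any]) -> tuple:
    # None cannot be ordered against numbers or strings in Python 3, so sort
    # on (is_none, type name, value) per field instead of the raw values
    return tuple((v is None, type(v).__name__, 0 if v is None else v) for v in row)


def _values_equal(a: Any, b: Any, rel_tol: float, abs_tol: float) -> bool:
    if isinstance(a, float) or isinstance(b, float):
        if a is None or b is None:
            return a is b
        if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
            return False
        if math.isnan(a) and math.isnan(b):
            return True  # treat NaN == NaN as a match in tests
        return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
    return a == b


def assert_rows_equal(actual: Iterable[Sequence[Any]], expected: Iterable[Sequence[Any]],
                      ignore_order: bool = True, rel_tol: float = 1e-9,
                      abs_tol: float = 0.0) -> None:
    """Compare rows field by field; duplicates count, floats use a tolerance."""
    actual, expected = list(actual), list(expected)
    assert len(actual) == len(expected), \
        f"Row count mismatch: {len(actual)} != {len(expected)}"
    if ignore_order:
        actual = sorted(actual, key=_sort_key)
        expected = sorted(expected, key=_sort_key)
    for i, (a_row, e_row) in enumerate(zip(actual, expected)):
        assert len(a_row) == len(e_row), f"Row {i}: width {len(a_row)} != {len(e_row)}"
        for a, e in zip(a_row, e_row):
            assert _values_equal(a, e, rel_tol, abs_tol), \
                f"Row {i}: {tuple(a_row)} != {tuple(e_row)}"
```

Typical use is `assert_rows_equal(result.collect(), [(1, "Alice", 100.0), ...])`, which avoids building an expected DataFrame and sidesteps nullable-flag differences in the schema.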

5. Testing Transformations

# Test complex transformations:
from pyspark.sql import Window
from pyspark.sql import functions as F

def test_complex_transformation(spark):
    """Test a complex transformation pipeline."""
    
    # Input data
    input_df = spark.createDataFrame(
        [(1, "2024-01-01", 100.0),
         (2, "2024-01-02", 200.0),
         (3, "2024-01-03", None)],
        ["id", "date", "amount"]
    )
    
    # Transformation under test
    def transform(df):
        return df \
            .withColumn("date", F.to_date("date")) \
            .withColumn("amount", F.coalesce(F.col("amount"), F.lit(0.0))) \
            .withColumn("year", F.year("date")) \
            .withColumn("month", F.month("date")) \
            .filter(F.col("amount") > 0)
    
    # Apply transformation
    result = transform(input_df)
    
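    # Assert results: both dates are in January 2024; the null-amount row
    # becomes 0.0 and is filtered out
    expected = spark.createDataFrame(
        [(1, 2024, 1, 100.0),
         (2, 2024, 1, 200.0)],
        ["id", "year", "month", "amount"]
    )
    
    # Select the asserted columns in the expected order; skip the schema check
    # because year/month are IntegerType while Python ints infer LongType
    assert_dataframe_equal(
        result.select("id", "year", "month", "amount"), expected, check_schema=False
    )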

def test_window_functions(spark):
    """Test window functions."""
    df = spark.createDataFrame(
        [("A", 1), ("A", 2), ("B", 3), ("B", 4)],
        ["group", "value"]
    )
    
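    window = Window.partitionBy("group").orderBy("value")
    result = df.withColumn("row_num", F.row_number().over(window))
    
    expected = spark.createDataFrame(
        [("A", 1, 1), ("A", 2, 2), ("B", 3, 1), ("B", 4, 2)],
        ["group", "value", "row_num"]
    )
    
    # Rows come back in shuffle order after partitionBy, and row_number() is a
    # non-nullable IntegerType column, so compare values only, ignoring order
    assert_dataframe_equal(result, expected, check_schema=False, check_row_order=False)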

6. Testing Spark SQL

# Test Spark SQL queries:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

def test_spark_sql(spark):
    """Test Spark SQL queries."""
    # Create temp view
    df = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob")],
        ["id", "name"]
    )
    df.createOrReplaceTempView("users")
    
    # SQL query under test
    result = spark.sql("""
        SELECT id, name, 
               CASE WHEN id = 1 THEN 'First' ELSE 'Other' END as position
        FROM users
        WHERE id > 0
    """)
    
    expected = spark.createDataFrame(
        [(1, "Alice", "First"), (2, "Bob", "Other")],
        ["id", "name", "position"]
    )
    
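    # The CASE column is non-nullable (both branches are literals) while
    # createDataFrame marks every column nullable, so skip the schema check
    assert_dataframe_equal(result, expected, check_schema=False)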

def test_udf_in_sql(spark):
    """Test UDF used in SQL."""
    @F.udf(returnType=StringType())
    def upper_udf(s):
        # Check for None explicitly: `if s` would also turn "" into None
        return s.upper() if s is not None else None
    
    spark.udf.register("upper_udf", upper_udf)
    
    df = spark.createDataFrame([("alice",)], ["name"])
    df.createOrReplaceTempView("names")
    
    result = spark.sql("SELECT upper_udf(name) as upper_name FROM names")
    
    expected = spark.createDataFrame([("ALICE",)], ["upper_name"])
    assert_dataframe_equal(result, expected)

7. Testing Streaming

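# Test Structured Streaming:
from pyspark.sql import functions as F

def test_streaming_query(spark):
    """Smoke-test a streaming transformation end to end."""
    import time
    
    # Rate source: emits rowsPerSecond rows with (timestamp, value) columns
    stream_df = spark.readStream \
        .format("rate") \
        .option("rowsPerSecond", 10) \
        .load()
    
    # Apply transformation
    result = stream_df \
        .withColumn("double_value", F.col("value") * 2) \
        .select("timestamp", "double_value")
    
    # Write to memory sink (an in-memory table named by queryName)
    query = result.writeStream \
        .format("memory") \
        .queryName("test_output") \
        .outputMode("append") \
        .start()
    
    try:
        # The rate source generates rows by wall-clock time, so early batches can
        # be empty; poll with a deadline instead of assuming one call is enough
        deadline = time.monotonic() + 30
        while spark.sql("SELECT * FROM test_output").count() == 0 \
                and time.monotonic() < deadline:
            query.processAllAvailable()
            time.sleep(0.5)
        
        output_df = spark.sql("SELECT * FROM test_output")
        assert output_df.count() > 0
        assert "double_value" in output_df.columns
    finally:
        # Stop even when an assertion fails, so the query cannot leak into other tests
        query.stop()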
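def test_streaming_with_watermark(spark):
    """Smoke-test a windowed aggregation with a watermark."""
    import time
    
    # current_timestamp() sets event time to the batch's processing time, so no
    # row is ever late: this checks that the query runs, not watermark semantics
    stream_df = spark.readStream \
        .format("rate") \
        .option("rowsPerSecond", 10) \
        .load() \
        .withColumn("event_time", F.current_timestamp())
    
    windowed = stream_df \
        .withWatermark("event_time", "10 seconds") \
        .groupBy(
            F.window("event_time", "5 seconds"),
        ).count()
    
    query = windowed.writeStream \
        .format("memory") \
        .queryName("windowed_output") \
        .outputMode("update") \
        .start()
    
    try:
        # The rate source may produce empty early batches, so poll with a deadline
        deadline = time.monotonic() + 30
        while spark.sql("SELECT * FROM windowed_output").count() == 0 \
                and time.monotonic() < deadline:
            query.processAllAvailable()
            time.sleep(0.5)
        
        output_df = spark.sql("SELECT * FROM windowed_output")
        assert output_df.count() > 0
    finally:
        query.stop()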

8. Performance Testing

# Test performance characteristics:

def test_performance(spark):
    """Test that transformation meets performance requirements."""
    import time
    
    # Create large dataset
    df = spark.range(10000000).withColumn(
        "value", F.randn()
    )
    
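    # Measure execution time (perf_counter is monotonic and high-resolution,
    # unlike time.time(), which can jump if the system clock changes)
    start_time = time.perf_counter()
    result = df.groupBy(F.floor(F.col("id") / 1000)).agg(
        F.sum("value").alias("total")
    ).count()
    execution_time = time.perf_counter() - start_time
    
    # Assert performance requirement. Wall-clock budgets vary by machine and
    # include JVM warm-up, so keep them generous and run these tests separately.
    assert execution_time < 10.0, \
        f"Transformation too slow: {execution_time:.2f}s > 10s"
    
    # Assert result correctness: 10,000,000 ids / 1,000 per bucket = 10,000 groups
    assert result == 10000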
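A single wall-clock measurement includes JVM warm-up and code generation, and shared CI runners add noise, which is a common source of flaky performance tests. One mitigation, sketched here under the assumption that a best-of-N measurement fits your time budget: run a warm-up pass, time several repetitions with `time.perf_counter`, and assert on the fastest. `best_of` is a hypothetical helper.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def best_of(action: Callable[[], T], repeat: int = 3,
            warmup: int = 1) -> Tuple[float, Optional[T]]:
    """Return (fastest elapsed seconds, result of the last timed run)."""
    if repeat < 1:
        raise ValueError("repeat must be >= 1")
    for _ in range(warmup):
        action()  # untimed: the first runs pay for class loading and codegen
    best = float("inf")
    result: Optional[T] = None
    for _ in range(repeat):
        start = time.perf_counter()
        result = action()
        best = min(best, time.perf_counter() - start)
    return best, result
```

In the test above this would read `elapsed, groups = best_of(lambda: df.groupBy(F.floor(F.col("id") / 1000)).count().count())` followed by the same two assertions; the minimum is less sensitive to a one-off slow run than a single sample.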

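def test_memory_usage(spark):
    """Test memory usage doesn't exceed threshold."""
    import os
    import psutil
    
    process = psutil.Process(os.getpid())
    
    def rss_mb():
        # Cached blocks live in the JVM, which a local PySpark run launches as a
        # child of this Python process, so sum RSS over the whole process tree
        total = 0
        for proc in [process] + process.children(recursive=True):
            try:
                total += proc.memory_info().rss
            except psutil.NoSuchProcess:
                pass  # a short-lived worker exited between listing and reading
        return total / 1024 / 1024  # MB
    
    memory_before = rss_mb()
    
    # Perform operation
    df = spark.range(1000000)
    result = df.cache()
    result.count()
    
    memory_after = rss_mb()
    memory_delta = memory_after - memory_before
    
    # RSS is a coarse signal (JVM heap growth, GC timing), so keep the budget loose
    assert memory_delta < 500, \
        f"Memory usage too high: {memory_delta:.1f}MB > 500MB"
    
    result.unpersist()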

9. Test Organization

# Project structure:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 240" width="100%" style={{ maxWidth: 500 }} xmlns="http://www.w3.org/2000/svg">
  <defs>
    <linearGradient id="proj-root" x1="0" y1="0" x2="0" y2="1">
      <stop offset="0%" stopColor="#6366f1"/>
      <stop offset="100%" stopColor="#4f46e5"/>
    </linearGradient>
    <linearGradient id="proj-folder" x1="0" y1="0" x2="0" y2="1">
      <stop offset="0%" stopColor="#3b82f6"/>
      <stop offset="100%" stopColor="#2563eb"/>
    </linearGradient>
    <linearGradient id="proj-file" x1="0" y1="0" x2="0" y2="1">
      <stop offset="0%" stopColor="#10b981"/>
      <stop offset="100%" stopColor="#059669"/>
    </linearGradient>
    <filter id="proj-shadow">
      <feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
    </filter>
  </defs>
  <rect x="20" y="10" width="120" height="28" rx="8" fill="url(#proj-root)" filter="url(#proj-shadow)"/>
  <text x="80" y="29" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">tests/</text>
  <line x1="80" y1="38" x2="80" y2="55" stroke="#94a3b8" strokeWidth="1.5"/>
  <line x1="80" y1="55" x2="30" y2="55" stroke="#94a3b8" strokeWidth="1.5"/>
  <line x1="80" y1="75" x2="30" y2="75" stroke="#94a3b8" strokeWidth="1.5"/>
  <line x1="80" y1="95" x2="30" y2="95" stroke="#94a3b8" strokeWidth="1.5"/>
  <line x1="80" y1="175" x2="30" y2="175" stroke="#94a3b8" strokeWidth="1.5"/>
  <line x1="30" y1="55" x2="30" y2="175" stroke="#94a3b8" strokeWidth="1.5"/>
  <rect x="40" y="48" width="180" height="20" rx="5" fill="url(#proj-file)" filter="url(#proj-shadow)"/>
  <text x="130" y="62" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">conftest.py</text>
  <text x="240" y="62" fill="#6b7280" fontFamily="Inter,system-ui,sans-serif" fontSize="9">Shared fixtures</text>
  <rect x="40" y="70" width="100" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
  <text x="90" y="84" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">unit/</text>
  <line x1="90" y1="90" x2="90" y2="100" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="90" y1="100" x2="60" y2="100" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="90" y1="115" x2="60" y2="115" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="90" y1="130" x2="60" y2="130" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="60" y1="100" x2="60" y2="130" stroke="#cbd5e1" strokeWidth="1"/>
  <rect x="70" y="96" width="180" height="16" rx="4" fill="#d1fae5"/>
  <text x="160" y="108" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_transformations.py</text>
  <rect x="70" y="112" width="180" height="16" rx="4" fill="#d1fae5"/>
  <text x="160" y="124" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_sql_queries.py</text>
  <rect x="70" y="128" width="180" height="16" rx="4" fill="#d1fae5"/>
  <text x="160" y="140" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_udfs.py</text>
  <rect x="40" y="95" width="130" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
  <text x="105" y="109" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">integration/</text>
  <line x1="105" y1="115" x2="105" y2="150" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="105" y1="150" x2="70" y2="150" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="105" y1="163" x2="70" y2="163" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="70" y1="150" x2="70" y2="163" stroke="#cbd5e1" strokeWidth="1"/>
  <rect x="80" y="146" width="190" height="16" rx="4" fill="#dbeafe"/>
  <text x="175" y="158" textAnchor="middle" fill="#1e40af" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_data_pipeline.py</text>
  <rect x="80" y="160" width="190" height="16" rx="4" fill="#dbeafe"/>
  <text x="175" y="172" textAnchor="middle" fill="#1e40af" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_streaming.py</text>
  <rect x="40" y="170" width="140" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
  <text x="110" y="184" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">performance/</text>
  <line x1="110" y1="190" x2="110" y2="205" stroke="#cbd5e1" strokeWidth="1"/>
  <line x1="110" y1="205" x2="70" y2="205" stroke="#cbd5e1" strokeWidth="1"/>
  <rect x="80" y="200" width="180" height="16" rx="4" fill="#fef3c7"/>
  <text x="170" y="212" textAnchor="middle" fill="#92400e" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_benchmarks.py</text>
</svg>
</div>

# conftest.py:
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    spark = SparkSession.builder \
        .master("local[*]") \
        .appName("test") \
        .config("spark.ui.enabled", "false") \
        .config("spark.sql.shuffle.partitions", "2") \
        .getOrCreate()
    yield spark
    spark.stop()

# Running tests:
# pytest tests/ -v                    # Run all tests
# pytest tests/unit/ -v               # Run unit tests only
# pytest tests/ -k "test_filter"      # Run specific test
# pytest tests/ --cov=src             # With coverage (requires the pytest-cov plugin)
# pytest tests/ -x                    # Stop on first failure
# pytest tests/ --timeout=60          # Timeout per test (requires the pytest-timeout plugin)
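To keep `performance/` and LocalCluster tests out of the default run, one option is the pattern from pytest's documentation for skipping marked tests unless a command-line flag is given. A sketch for `conftest.py`; the `slow` marker and the `--runslow` flag are names you choose, not pytest built-ins.

```python
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--runslow", action="store_true", default=False,
        help="also run tests marked slow (performance/, LocalCluster)",
    )


def pytest_configure(config):
    # Registering the marker avoids PytestUnknownMarkWarning
    config.addinivalue_line("markers", "slow: long-running Spark test")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        return  # flag given: run everything
    skip_slow = pytest.mark.skip(reason="needs --runslow")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```

Tests then opt in with `@pytest.mark.slow`, `pytest tests/` stays fast, and `pytest tests/ --runslow` runs the full suite.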

⚠️ Common Pitfall

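Sessions that are never stopped, and state left behind inside a shared session (running streaming queries, temp views, cached DataFrames), cause resource leaks and order-dependent, flaky tests. Create the session in a session-scoped fixture that calls spark.stop() on teardown, and clean up per-test state with query.stop(), spark.catalog.dropTempView(), and unpersist().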

💡 Interview Tip

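When discussing testing, mention that test data should be deterministic. Avoid F.rand() and F.randn() in test data: even with a seed, the generated values depend on how the data is partitioned, so a change in parallelism changes the results. Non-deterministic tests are flaky and hard to debug; prefer literal rows, or rows generated in Python from a fixed seed.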
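When a test needs more rows than you want to type, generate them in Python from a seeded `random.Random` instance and pass the list to `spark.createDataFrame`; the values are fixed before Spark sees them, so partitioning cannot change them. `make_transactions` is a hypothetical helper.

```python
import random
from typing import List, Tuple


def make_transactions(n: int, seed: int = 42) -> List[Tuple[int, str, float]]:
    """Build n reproducible (id, name, amount) rows: the same seed gives the same rows."""
    rng = random.Random(seed)  # private generator, no global random state shared with other tests
    names = ["Alice", "Bob", "Charlie", "Dana"]
    return [
        (i, rng.choice(names), round(rng.uniform(1.0, 500.0), 2))
        for i in range(1, n + 1)
    ]
```

For example, `spark.createDataFrame(make_transactions(1000), "id INT, name STRING, amount DOUBLE")` yields the same 1,000 rows on every machine and every run.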


Summary

| Approach | Speed | Use Case | Key Benefit |
|---|---|---|---|
| SharedSparkSession | Fast | Unit tests | Quick feedback |
| LocalCluster | Medium | Integration tests | Realistic execution |
| Memory Sink | Fast | Streaming tests | No external dependencies |
| Performance Tests | Slow | Benchmarks | Catch regressions |

The key to Spark testing is: use SharedSparkSession for speed, create deterministic test data, and test both correctness and performance.
