Testing & Mocking in Python
Difficulty: Medium | Companies: Google, Meta, Amazon, Netflix, Stripe
Pytest Fundamentals
import asyncio
from datetime import datetime
from typing import Dict, List, Optional
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
# Sample code to test
class UserService:
    """Toy user service used as the system under test in the examples below.

    Both collaborators are injected so tests can substitute mocks:
    ``db_session`` is called as ``query(sql, *params)`` and
    ``execute(sql, *params)``; ``email_service`` is called as
    ``send(address, message)``.
    """

    def __init__(self, db_session, email_service):
        self.db = db_session
        self.email_service = email_service

    def create_user(self, user_data: Dict) -> Dict:
        """Insert a new user and send a welcome email.

        Args:
            user_data: mapping with ``'email'`` and ``'name'`` keys.

        Returns:
            Whatever ``db.execute`` returns for the INSERT.

        Raises:
            ValueError: if the email is missing/empty, or ``db.query`` finds
                an existing row for that email.
            KeyError: if ``'name'`` is absent (it is read only after the
                duplicate check).
        """
        email = user_data.get('email')
        if not email:
            raise ValueError("Email is required")

        # Duplicate check runs before anything is written.
        if self.db.query("SELECT * FROM users WHERE email = ?", email):
            raise ValueError("User already exists")

        created = self.db.execute(
            "INSERT INTO users (email, name) VALUES (?, ?)",
            email,
            user_data['name'],
        )
        self.email_service.send(email, "Welcome!")
        return created

    def get_user(self, user_id: int) -> Dict:
        """Fetch a user row by id.

        Raises:
            ValueError: if ``db.query`` returns a falsy result.
        """
        found = self.db.query("SELECT * FROM users WHERE id = ?", user_id)
        if found:
            return found
        raise ValueError("User not found")
# Fixtures
@pytest.fixture
def mock_db():
    """Provide a mock DB session: ``query`` finds nothing, ``execute`` returns a canned row."""
    # Mock's constructor forwards dotted keyword names to configure_mock().
    return Mock(**{
        "query.return_value": None,
        "execute.return_value": {"id": 1, "email": "test@example.com", "name": "Test User"},
    })
@pytest.fixture
def mock_email():
    """Provide a mock email service whose ``send`` reports success (True)."""
    return Mock(**{"send.return_value": True})
@pytest.fixture
def user_service(mock_db, mock_email):
    """Build a UserService wired to the mocked DB and email fixtures."""
    service = UserService(db_session=mock_db, email_service=mock_email)
    return service
@pytest.fixture
def sample_user_data():
    """Return a valid user payload (email + name) for creation tests."""
    return dict(email="test@example.com", name="Test User")
# Test Classes
class TestUserService:
    """Test suite for UserService, driven by the mocked fixtures above.

    Failure-path tests also assert that no side effects happened: a rejected
    request must not insert a row or send a welcome email.
    """

    def test_create_user_success(self, user_service, sample_user_data, mock_db, mock_email):
        """A new email is checked for duplicates, inserted once, and welcomed."""
        result = user_service.create_user(sample_user_data)

        # The service must hand back exactly what the INSERT returned.
        assert result == mock_db.execute.return_value
        assert result['email'] == sample_user_data['email']
        mock_db.query.assert_called_once()
        mock_db.execute.assert_called_once()
        mock_email.send.assert_called_once_with(sample_user_data['email'], "Welcome!")

    def test_create_user_missing_email(self, user_service, mock_db, mock_email):
        """A missing email is rejected before any collaborator is touched."""
        with pytest.raises(ValueError, match="Email is required"):
            user_service.create_user({"name": "Test"})
        mock_db.query.assert_not_called()
        mock_db.execute.assert_not_called()
        mock_email.send.assert_not_called()

    def test_create_user_duplicate_email(self, user_service, sample_user_data, mock_db, mock_email):
        """An existing row for the email aborts creation with no insert and no email."""
        mock_db.query.return_value = {"id": 1}
        with pytest.raises(ValueError, match="User already exists"):
            user_service.create_user(sample_user_data)
        mock_db.execute.assert_not_called()
        mock_email.send.assert_not_called()

    def test_get_user_found(self, user_service, mock_db):
        """An existing id returns the row the database produced."""
        row = {"id": 1, "email": "test@example.com"}
        mock_db.query.return_value = row
        result = user_service.get_user(1)
        assert result == row
        mock_db.query.assert_called_once()
        # The requested id must reach the query as its bound parameter.
        assert mock_db.query.call_args.args[1:] == (1,)

    def test_get_user_not_found(self, user_service, mock_db):
        """A falsy query result is reported as ValueError('User not found')."""
        mock_db.query.return_value = None
        with pytest.raises(ValueError, match="User not found"):
            user_service.get_user(999)
        assert mock_db.query.call_args.args[1:] == (999,)
# Parametrized Tests
@pytest.mark.parametrize("email,expected_valid", [
    ("valid@example.com", True),
    ("invalid-email", False),
    ("@nodomain.com", False),
    ("user@.com", False),
    ("user@domain.co", True),
])
def test_email_validation(email, expected_valid):
    """Check a minimal 'local@domain.tld' regex against valid and invalid addresses."""
    import re
    pattern = re.compile(r'^[^@]+@[^@]+\.[^@]+$')
    matched = pattern.match(email) is not None
    assert matched == expected_valid
# Markers
@pytest.mark.slow
def test_slow_operation():
    """Stand-in for a long-running test, tagged `slow` so it can be deselected."""
    from time import sleep
    sleep(1)  # simulate one second of expensive work
    assert True
@pytest.mark.integration
def test_database_integration():
    """Integration test against a real test database -- not implemented yet.

    An empty body would be reported as a pass even though nothing is
    exercised; skipping makes the missing coverage visible in the report.
    """
    pytest.skip("integration test not implemented: requires a test database")
ℹ️ Tip:
Use fixtures for setup and teardown logic (code after a fixture's `yield` runs as teardown). Fixtures can be scoped to `function` (the default), `class`, `module`, `package`, or `session`.
Advanced Mocking
from unittest.mock import Mock, patch, MagicMock, AsyncMock, call
from typing import AsyncGenerator
import asyncio
class APIClient:
    """Minimal HTTP JSON client used as a mocking target in the tests below.

    Call :meth:`connect` before :meth:`get` / :meth:`post`; it creates the
    underlying ``requests.Session`` and installs the bearer-token header.
    """

    def __init__(self, base_url: str, api_key: str):
        self.base_url = base_url
        self.api_key = api_key
        self.session = None  # set by connect()

    def connect(self):
        """Create the HTTP session and attach the ``Authorization`` header."""
        # Local import: the module itself needs no `requests` until connect() runs.
        import requests
        self.session = requests.Session()
        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})

    def _url(self, endpoint: str) -> str:
        """Join base_url and endpoint with exactly one '/' between them."""
        return f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def _require_session(self):
        """Return the active session, or raise if connect() was never called."""
        if self.session is None:
            raise RuntimeError("APIClient is not connected; call connect() first")
        return self.session

    def get(self, endpoint: str) -> dict:
        """GET ``endpoint`` and return the decoded JSON body.

        Raises:
            RuntimeError: if connect() has not been called.
            Whatever ``response.raise_for_status()`` raises (e.g.
            ``requests.HTTPError``) propagates unchanged.
        """
        response = self._require_session().get(self._url(endpoint))
        response.raise_for_status()
        return response.json()

    def post(self, endpoint: str, data: dict) -> dict:
        """POST ``data`` as a JSON body to ``endpoint`` and return the decoded JSON reply.

        Raises:
            RuntimeError: if connect() has not been called.
            Whatever ``response.raise_for_status()`` raises propagates unchanged.
        """
        response = self._require_session().post(self._url(endpoint), json=data)
        response.raise_for_status()
        return response.json()
class TestAPIClient:
    """Tests that swap ``requests.Session`` for a mock via ``@patch``.

    ``connect()`` imports ``requests`` and calls ``requests.Session()`` at
    call time, so patching that module attribute intercepts the session.
    """

    @patch('requests.Session')
    def test_connect(self, mock_session_class):
        """connect() installs the bearer token on the freshly created session."""
        fake_session = Mock()
        mock_session_class.return_value = fake_session

        APIClient("https://api.example.com", "test-key").connect()

        fake_session.headers.update.assert_called_once_with(
            {"Authorization": "Bearer test-key"}
        )

    @patch('requests.Session')
    def test_get_request(self, mock_session_class):
        """get() hits base_url/endpoint and returns the decoded JSON."""
        fake_response = Mock(**{"json.return_value": {"data": "test"}})
        fake_session = Mock(**{"get.return_value": fake_response})
        mock_session_class.return_value = fake_session

        client = APIClient("https://api.example.com", "test-key")
        client.connect()

        assert client.get("users") == {"data": "test"}
        fake_session.get.assert_called_once_with("https://api.example.com/users")

    @patch('requests.Session')
    def test_api_error_handling(self, mock_session_class):
        """An HTTPError from raise_for_status() propagates out of get()."""
        import requests
        failing_response = Mock()
        failing_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
        mock_session_class.return_value = Mock(**{"get.return_value": failing_response})

        client = APIClient("https://api.example.com", "test-key")
        client.connect()

        with pytest.raises(requests.HTTPError):
            client.get("nonexistent")
# Async Mocking
class AsyncService:
    """Example async service used as a mocking target."""

    async def fetch_data(self, url: str) -> dict:
        """Simulate fetching ``url``: sleep 0.1 s, then return a canned payload."""
        await asyncio.sleep(0.1)
        return {"url": url, "data": "response"}

    async def process_batch(self, items: List[str]) -> List[dict]:
        """Fetch every item concurrently and return the results in input order.

        Awaiting each fetch in turn would make a batch take ``len(items)``
        times the per-item latency; ``asyncio.gather`` overlaps the waits
        and still returns results in the same order as ``items``. If a
        fetch raises, that exception propagates to the caller; gather does
        not cancel the remaining fetches.
        """
        results = await asyncio.gather(*(self.fetch_data(item) for item in items))
        return list(results)
class TestAsyncService:
    """Async tests using AsyncMock.

    ``@pytest.mark.asyncio`` is provided by the pytest-asyncio plugin, which
    must be installed for these coroutines to be run.
    """

    @pytest.mark.asyncio
    async def test_fetch_data(self):
        """The real fetch_data runs; only its slow dependency (asyncio.sleep) is mocked.

        Patching fetch_data itself and then calling it would only exercise
        the mock, not the code under test.
        """
        service = AsyncService()
        with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
            result = await service.fetch_data("test")
        assert result == {"url": "test", "data": "response"}
        mock_sleep.assert_awaited_once_with(0.1)

    @pytest.mark.asyncio
    async def test_process_batch(self):
        """process_batch awaits fetch_data once per item and keeps input order."""
        service = AsyncService()
        urls = ["url1", "url2", "url3"]
        with patch.object(service, 'fetch_data', new_callable=AsyncMock) as mock_fetch:
            # A sync side_effect: the awaited mock returns this function's result.
            mock_fetch.side_effect = lambda url: {"url": url, "data": "mocked"}
            result = await service.process_batch(urls)
        assert result == [{"url": u, "data": "mocked"} for u in urls]
        assert mock_fetch.await_count == 3
        mock_fetch.assert_has_awaits([call(u) for u in urls])
# Mocking Context Managers
class DatabaseConnection:
    """Toy context manager that announces connect/disconnect on stdout."""

    def query(self, sql: str):
        """Echo the SQL back, prefixed with ``Result: ``."""
        return "Result: {}".format(sql)

    def __enter__(self):
        """Print a connect message and bind the connection itself to the ``as`` target."""
        print("Connecting to database")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Print a disconnect message; returning False lets any exception propagate."""
        print("Disconnecting from database")
        return False
def test_context_manager_mocking():
    """Replace DatabaseConnection with a mock and drive it through a ``with`` block.

    patch() must target the module where the name is looked up. Under
    pytest this file is imported under its own module name, not
    ``__main__``, so the target is built from ``__name__``.
    """
    with patch(f"{__name__}.DatabaseConnection") as MockDB:
        mock_instance = Mock()
        mock_instance.query.return_value = "Mocked result"
        # patch() creates a MagicMock, which already supports the context-manager
        # protocol; configure what __enter__ hands to the `as` target.
        MockDB.return_value.__enter__.return_value = mock_instance
        MockDB.return_value.__exit__.return_value = False

        with DatabaseConnection() as db:
            result = db.query("SELECT * FROM users")

    assert result == "Mocked result"
    mock_instance.query.assert_called_once_with("SELECT * FROM users")
    MockDB.return_value.__exit__.assert_called_once()
Test-Driven Development
# TDD Cycle: Red -> Green -> Refactor
# Step 1: Write failing test (Red)
def test_calculate_discount():
    """Pin PriceCalculator.calculate_discount's contract (written first, TDD-style)."""
    calculator = PriceCalculator()
    # Results are computed floats, so compare with a tolerance, not exact ==.
    # Regular discount
    assert calculator.calculate_discount(100, 10) == pytest.approx(90)
    assert calculator.calculate_discount(19.99, 15) == pytest.approx(16.9915)
    # No discount
    assert calculator.calculate_discount(100, 0) == pytest.approx(100)
    # Full discount
    assert calculator.calculate_discount(100, 100) == pytest.approx(0)
    # Invalid discount: below 0 and above 100 are both rejected with a clear message
    with pytest.raises(ValueError, match="between 0 and 100"):
        calculator.calculate_discount(100, -10)
    with pytest.raises(ValueError, match="between 0 and 100"):
        calculator.calculate_discount(100, 110)
# Step 2: Write minimal code to pass (Green)
class PriceCalculator:
    """Simple price calculator (the minimal 'green' implementation)."""

    def calculate_discount(self, price: float, discount_percent: float) -> float:
        """Return ``price`` reduced by ``discount_percent`` percent.

        Args:
            price: original price.
            discount_percent: discount in percent, inclusive range [0, 100].

        Raises:
            ValueError: if ``discount_percent`` is outside [0, 100] or is NaN.
        """
        # The chained comparison is False for NaN, so NaN is rejected as well
        # (a `< 0 or > 100` test would let NaN through and return NaN).
        if not 0 <= discount_percent <= 100:
            raise ValueError("Discount must be between 0 and 100")
        return price * (1 - discount_percent / 100)
# Step 3: Refactor (maintain passing tests)
class PriceCalculatorRefactored:
    """Refactored calculator: shared discount validation plus a flat tax rate."""

    def __init__(self, tax_rate: float = 0.0):
        # Applied multiplicatively as (1 + tax_rate), e.g. 0.2 adds 20%.
        self.tax_rate = tax_rate

    def _validate_discount(self, discount_percent: float):
        """Raise ValueError unless 0 <= discount_percent <= 100 (NaN also fails)."""
        if 0 <= discount_percent <= 100:
            return
        raise ValueError("Discount must be between 0 and 100")

    def calculate_discount(self, price: float, discount_percent: float) -> float:
        """Return ``price`` reduced by ``discount_percent`` percent."""
        self._validate_discount(discount_percent)
        multiplier = 1 - discount_percent / 100
        return price * multiplier

    def calculate_final_price(self, price: float, discount_percent: float) -> float:
        """Return the discounted price with tax added on top."""
        return self.calculate_discount(price, discount_percent) * (1 + self.tax_rate)
# Property-based Testing with Hypothesis
from hypothesis import given, strategies as st
@given(
    price=st.floats(min_value=0, max_value=1000000),
    discount=st.floats(min_value=0, max_value=100)
)
def test_discount_properties(price, discount):
    """Property-based checks for calculate_discount over valid inputs.

    With both bounds set, the float strategies generate no NaN or infinity.
    The properties are stated independently of the formula: re-computing
    ``price * (1 - discount / 100)`` in the test would only repeat the
    implementation and could never fail.
    """
    calculator = PriceCalculator()
    result = calculator.calculate_discount(price, discount)

    # Property: never negative and never more than the original price.
    assert 0 <= result <= price
    # Property: boundary identities -- no discount keeps the price, a full one zeroes it.
    assert calculator.calculate_discount(price, 0) == price
    assert calculator.calculate_discount(price, 100) == 0
    # Property: monotonic -- a larger discount never yields a higher price.
    larger = min(discount + 1, 100)
    assert calculator.calculate_discount(price, larger) <= result
# Fixture Factory Pattern
@pytest.fixture
def user_factory():
    """Factory fixture: returns a callable that builds test-user dicts.

    Every field has a default, so each test overrides only what it cares about.
    """
    def _create_user(
        email: str = "test@example.com",
        name: str = "Test User",
        is_active: bool = True,
        created_at: Optional[datetime] = None,
    ):
        """Build one user dict; ``created_at`` defaults to ``datetime.now()``.

        Pass an explicit ``created_at`` when a test needs a deterministic timestamp.
        """
        return {
            "email": email,
            "name": name,
            "is_active": is_active,
            "created_at": datetime.now() if created_at is None else created_at,
        }
    return _create_user
def test_user_creation(user_factory):
    """The factory applies its defaults and honours per-field overrides."""
    user = user_factory()
    assert user["email"] == "test@example.com"
    assert user["is_active"] is True

    inactive_user = user_factory(is_active=False)
    # Identity check instead of `== False` (flake8 E712): the flag must be the bool False.
    assert inactive_user["is_active"] is False
⚠️ Tip:
Follow the AAA pattern (Arrange, Act, Assert). Keep each test simple and focused on a single behavior.
Test Configuration
# conftest.py - Shared fixtures and configuration
import pytest
import tempfile
import os
from pathlib import Path
@pytest.fixture(scope="session")
def temp_dir():
    """Yield one temporary directory, as a Path, shared by the whole test session.

    The directory and its contents are removed when the session finishes.
    """
    scratch = tempfile.TemporaryDirectory()
    with scratch as path_str:
        yield Path(path_str)
@pytest.fixture(autouse=True)
def setup_test_env(monkeypatch):
    """Set test-only environment variables for every test (autouse).

    monkeypatch undoes the changes when each test finishes.
    """
    overrides = (
        ("TESTING", "true"),
        ("DATABASE_URL", "sqlite:///:memory:"),
    )
    for key, value in overrides:
        monkeypatch.setenv(key, value)
@pytest.fixture
def sample_config():
    """Return a debug configuration backed by an in-memory SQLite database."""
    config = dict(
        debug=True,
        database="sqlite:///:memory:",
        secret_key="test-secret-key",
    )
    return config
# Register custom markers in pytest.ini (in setup.cfg the section is [tool:pytest] instead):
# [pytest]
# markers =
#     slow: marks tests as slow (deselect with '-m "not slow"')
#     integration: marks tests as integration tests
#     unit: marks tests as unit tests
# Example usage with markers
@pytest.mark.unit
def test_simple_calculation():
    """Trivial unit test illustrating the `unit` marker."""
    total = 1 + 1
    assert total == 2
@pytest.mark.integration
def test_database_operation():
    """Integration test illustrating the `integration` marker -- not implemented yet.

    An empty body would be reported as a pass; skipping keeps the missing
    coverage visible in the test report.
    """
    pytest.skip("integration test not implemented")
# Coverage configuration
# .coveragerc
# [run]
# source = src
# omit = tests/*
#
# [report]
# fail_under = 80
# show_missing = true
Follow-Up Questions
- Explain the difference between mocking and stubbing.
- When would you use integration tests vs. unit tests?
- How do you test async code effectively?
- What is the purpose of test fixtures?
- How do you achieve high test coverage without testing implementation details?