Python Unit Testing — pytest & Test-Driven Development
Testing ensures your code works correctly and continues to work as you make changes. pytest is the most popular testing framework for Python, offering powerful features such as fixtures, parametrization, and plugins.
Learning Objectives
- Write tests with pytest assertions and conventions
- Use fixtures for test setup, teardown, and dependency injection
- Parametrize tests for data-driven testing
- Mock external dependencies (APIs, databases, file systems)
- Measure test coverage and follow TDD principles
- Organize tests with conftest.py and test classes
pytest Basics
# calculator.py
def add(a, b):
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract(a, b):
    """Return the difference ``a - b``."""
    difference = a - b
    return difference
def multiply(a, b):
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
def divide(a, b):
    """Return the quotient ``a / b``.

    Raises:
        ValueError: If ``b`` equals zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero")
# test_calculator.py
import pytest
from calculator import add, subtract, multiply, divide
def test_add_positive():
    """Two positive operands are summed."""
    total = add(2, 3)
    assert total == 5


def test_add_negative():
    """Two negative operands are summed."""
    total = add(-1, -1)
    assert total == -2


def test_add_zero():
    """Adding zero leaves the other operand unchanged."""
    total = add(0, 5)
    assert total == 5


def test_subtract():
    """The second operand is subtracted from the first."""
    difference = subtract(10, 3)
    assert difference == 7


def test_multiply():
    """Two operands are multiplied."""
    product = multiply(4, 5)
    assert product == 20


def test_divide():
    """Division yields a float quotient."""
    quotient = divide(10, 2)
    assert quotient == 5.0


def test_divide_by_zero():
    """A zero divisor raises ValueError with the documented message."""
    with pytest.raises(ValueError, match="Cannot divide by zero"):
        divide(10, 0)
# Run: pytest test_calculator.py -v
Assertion Techniques
import pytest
def test_assertions():
    """Tour of the assertion styles that work with plain ``assert`` in pytest."""
    # Basic equality
    assert 2 + 2 == 4

    # Check truthiness: a non-empty list is truthy, an empty one is falsy
    assert [1, 2, 3]
    assert not []

    # Check membership
    assert 3 in [1, 2, 3]
    assert "hello" in "hello world"

    # Check exceptions
    with pytest.raises(ZeroDivisionError):
        1 / 0

    # Capture the exception to inspect its message afterwards
    with pytest.raises(ValueError) as exc_info:
        int("not a number")
    assert "invalid literal" in str(exc_info.value)

    # Check approximate equality for floats
    assert 0.1 + 0.2 == pytest.approx(0.3)

    # Custom error message
    # Computed inline: this snippet only imports pytest, so it must not rely
    # on an ``add`` helper defined somewhere else.
    result = 2 + 2
    assert result == 4, f"Expected 4, got {result}"
Fixtures
Fixtures provide test setup, teardown, and dependency injection:
import pytest
# Simple fixture
@pytest.fixture
def sample_data():
    """Return a small dataset: user names, a user count, and scores."""
    names = ["Alice", "Bob", "Charlie"]
    points = [85, 92, 78]
    return {"users": names, "count": 3, "scores": points}
def test_user_count(sample_data):
    """The stored count equals the number of user names."""
    expected = len(sample_data["users"])
    assert sample_data["count"] == expected


def test_average_score(sample_data):
    """The mean of the sample scores is approximately 85."""
    scores = sample_data["scores"]
    mean = sum(scores) / len(scores)
    assert mean == pytest.approx(85.0)
Fixtures with Setup and Teardown
import pytest
import tempfile
import os
@pytest.fixture
def temp_file():
    """Yield the path of a temporary ``.txt`` file containing ``test content``.

    The file is removed once the test is done. Cleanup also happens when
    setup itself fails part-way, so neither the open descriptor nor the file
    on disk is left behind.
    """
    # Setup: mkstemp returns an already-open descriptor plus the file path.
    fd, path = tempfile.mkstemp(suffix='.txt')
    try:
        try:
            os.write(fd, b"test content")
        finally:
            # Close the descriptor even if the write raised.
            os.close(fd)
        yield path  # Provide the path to the test
    finally:
        # Teardown
        os.unlink(path)
def test_read_file(temp_file):
    """The fixture file holds the expected text."""
    with open(temp_file, 'r') as stream:
        text = stream.read()
    assert text == "test content"


def test_file_exists(temp_file):
    """The fixture path exists on disk while the test runs."""
    present = os.path.exists(temp_file)
    assert present
Database Fixture
import pytest
import sqlite3
@pytest.fixture
def db():
    """Yield an in-memory SQLite connection seeded with two users."""
    connection = sqlite3.connect(':memory:')
    cur = connection.cursor()

    # Schema
    cur.execute('''
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            email TEXT UNIQUE
        )
    ''')

    # Seed rows
    seed_rows = [
        ("Alice", "alice@test.com"),
        ("Bob", "bob@test.com"),
    ]
    cur.executemany(
        "INSERT INTO users (name, email) VALUES (?, ?)",
        seed_rows,
    )
    connection.commit()

    yield connection

    # Teardown: runs after the test that used the connection
    connection.close()
def test_insert_user(db):
    """Inserting one more user brings the row count to three."""
    new_user = ("Charlie", "charlie@test.com")
    db.execute("INSERT INTO users (name, email) VALUES (?, ?)", new_user)
    db.commit()

    count_row = db.execute("SELECT COUNT(*) FROM users").fetchone()
    assert count_row[0] == 3


def test_query_users(db):
    """The seeded users come back in alphabetical order."""
    rows = db.execute("SELECT name FROM users ORDER BY name").fetchall()
    names = [row[0] for row in rows]
    assert names == ["Alice", "Bob"]
Fixture Scope
import pytest
# scope='function' — runs for each test function (default)
@pytest.fixture(scope='function')
def function_fixture():
    """Function scope (the default): evaluated for each test function."""
    return "per test"


# scope='class' — runs once per test class
@pytest.fixture(scope='class')
def class_fixture():
    """Class scope."""
    return "per class"


# scope='module' — runs once per module
@pytest.fixture(scope='module')
def module_fixture():
    """Module scope."""
    return "per module"


# scope='session' — runs once for entire test session
@pytest.fixture(scope='session')
def session_fixture():
    """Session scope."""
    return "per session"


# Use autouse for fixtures that should always run
@pytest.fixture(autouse=True)
def setup_logging():
    """Print a message before and after every test, without being requested."""
    print("\nSetting up logging")
    yield  # the test body runs here
    print("\nTearing down logging")
Parametrize
Data-driven testing with multiple inputs:
import pytest
def double(x):
    """Return ``x`` multiplied by two."""
    doubled = x * 2
    return doubled
@pytest.mark.parametrize("value,expected", [
    (1, 2),
    (2, 4),
    (3, 6),
    (0, 0),
    (-1, -2),
    (-5, -10),
])
def test_double(value, expected):
    """double() returns twice its argument for each (value, expected) pair."""
    # The parameter is called ``value`` rather than ``input`` so the builtin
    # ``input`` function is not shadowed inside the test.
    assert double(value) == expected
Multiple Parameters
import pytest
def calculate_bmi(weight_kg, height_m):
    """Return the body-mass index: weight divided by height squared."""
    height_squared = height_m ** 2
    return weight_kg / height_squared
@pytest.mark.parametrize("weight,height,expected", [
    (70, 1.75, 22.86),
    (80, 1.80, 24.69),
    (60, 1.60, 23.44),
    (90, 1.90, 24.93),
])
def test_bmi(weight, height, expected):
    """The computed BMI is within 1% (relative) of the expected value."""
    bmi = calculate_bmi(weight, height)
    assert bmi == pytest.approx(expected, rel=1e-2)
# Combine multiple parametrize decorators (cartesian product)
@pytest.mark.parametrize("x", [1, 2, 3])
@pytest.mark.parametrize("y", [10, 20])
def test_multiply(x, y):
    """Check ``x * y`` against repeated addition for every (x, y) pair."""
    # Compare with an independently computed value: asserting
    # ``x * y == x * y`` would pass no matter what the product is.
    repeated_addition = sum(x for _ in range(y))
    assert x * y == repeated_addition
# Runs 6 tests (the cartesian product of x and y):
# (1,10), (2,10), (3,10), (1,20), (2,20), (3,20)
Parametrize with IDs
import pytest
def parse_url(url):
    """Split ``url`` into a dict with its scheme, host, and path."""
    from urllib.parse import urlparse

    parts = urlparse(url)
    return dict(scheme=parts.scheme, host=parts.hostname, path=parts.path)
@pytest.mark.parametrize("url,expected", [
    ("https://example.com/path", {"scheme": "https", "host": "example.com", "path": "/path"}),
    ("http://localhost:8080/api", {"scheme": "http", "host": "localhost", "path": "/api"}),
    ("ftp://files.example.com/data", {"scheme": "ftp", "host": "files.example.com", "path": "/data"}),
], ids=["https-url", "localhost-url", "ftp-url"])
def test_parse_url(url, expected):
    """Each URL parses into the expected scheme/host/path mapping."""
    assert parse_url(url) == expected
Mocking
Mock external dependencies to isolate units under test:
from unittest.mock import patch, MagicMock, AsyncMock
import requests
# Function that depends on external API
def get_user_data(user_id):
    """Fetch a user from the example API and return the decoded JSON body.

    ``raise_for_status()`` is called on the response before decoding.
    """
    endpoint = f"https://api.example.com/users/{user_id}"
    reply = requests.get(endpoint)
    reply.raise_for_status()
    return reply.json()
# Test with mock
@patch('requests.get')
def test_get_user(mock_get):
    """get_user_data returns the JSON payload and requests the expected URL."""
    # Arrange: a fake response whose json() yields a canned user
    fake_response = MagicMock()
    fake_response.json.return_value = {"id": 1, "name": "Alice"}
    fake_response.raise_for_status = MagicMock()
    mock_get.return_value = fake_response

    # Act
    user = get_user_data(1)

    # Assert
    assert user["name"] == "Alice"
    mock_get.assert_called_once_with("https://api.example.com/users/1")
Mocking File Operations
from unittest.mock import mock_open, patch
import json
def read_config(filepath):
    """Load the JSON document stored at ``filepath`` and return it."""
    with open(filepath, 'r') as handle:
        config = json.load(handle)
    return config
def test_read_config():
    """read_config parses the JSON text served by a mocked ``open``."""
    fake_open = mock_open(read_data='{"debug": true, "port": 8080}')

    # Only ``open`` is replaced. The real ``json.load`` parses the fake file
    # contents, so the test exercises actual parsing instead of asserting on
    # a mock's canned return value.
    with patch('builtins.open', fake_open):
        config = read_config('config.json')

    assert config["debug"] is True
    assert config["port"] == 8080
    fake_open.assert_called_once_with('config.json', 'r')
Mocking Database Calls
from unittest.mock import patch, MagicMock
class UserService:
    """User lookups and inserts on top of a connection-like ``db`` object."""

    def __init__(self, db):
        self.db = db

    def get_user(self, user_id):
        """Return the first row whose id matches ``user_id``."""
        cursor = self.db.execute("SELECT * FROM users WHERE id = ?", (user_id,))
        return cursor.fetchone()

    def create_user(self, name, email):
        """Insert a user row and commit."""
        params = (name, email)
        self.db.execute("INSERT INTO users (name, email) VALUES (?, ?)", params)
        self.db.commit()
def test_get_user():
    """get_user returns whatever the mocked fetchone() yields."""
    fake_db = MagicMock()
    fake_db.execute.return_value.fetchone.return_value = {"id": 1, "name": "Alice"}

    found = UserService(fake_db).get_user(1)

    assert found["name"] == "Alice"
    fake_db.execute.assert_called_once()


def test_create_user():
    """create_user executes one statement and commits once."""
    fake_db = MagicMock()

    UserService(fake_db).create_user("Bob", "bob@test.com")

    fake_db.execute.assert_called_once()
    fake_db.commit.assert_called_once()
Test Organization
conftest.py — Shared Fixtures
# conftest.py (in test directory)
import pytest
import requests
@pytest.fixture
def api_client():
    """Yield a Flask test client for a fresh app with ``TESTING`` enabled.

    The client is created in a ``with`` block, so its context is exited once
    the test using the fixture has finished.
    """
    # Local import: Flask is only loaded when a test requests this fixture.
    # (The previously unused ``flask_testing`` import was dropped; it forced
    # an extra third-party package to be installed for no benefit.)
    from flask import Flask

    # NOTE(review): no routes are registered on this app here; tests that hit
    # endpoints such as /users presumably need the real application — confirm.
    app = Flask(__name__)
    app.config['TESTING'] = True
    with app.test_client() as client:
        yield client
@pytest.fixture
def sample_user():
    """Return a user payload with name, email, and age."""
    return dict(name="Test User", email="test@example.com", age=25)
@pytest.fixture(autouse=True)
def reset_database():
    """Autouse hook around every test; setup and teardown are placeholders."""
    # Setup: code placed here runs before the test.
    yield
    # Teardown: code placed here runs after the test to clean up.
Test Classes
import pytest
class TestUserAPI:
    """Endpoint tests for the ``/users`` resource."""

    def test_create_user(self, api_client, sample_user):
        """POSTing a user answers 201 and echoes the submitted name."""
        reply = api_client.post('/users', json=sample_user)
        assert reply.status_code == 201

        body = reply.get_json()
        assert body['name'] == sample_user['name']

    def test_get_user(self, api_client):
        """A previously created user can be fetched by id."""
        # Seed a user first; it is presumably stored under id 1.
        seed = {"name": "Alice", "email": "a@b.com"}
        api_client.post('/users', json=seed)

        reply = api_client.get('/users/1')
        assert reply.status_code == 200
        assert reply.get_json()['name'] == 'Alice'

    def test_get_nonexistent_user(self, api_client):
        """Fetching an unknown id answers 404."""
        reply = api_client.get('/users/999')
        assert reply.status_code == 404

    def test_delete_user(self, api_client):
        """Deleting a previously created user answers 200."""
        seed = {"name": "Bob", "email": "b@b.com"}
        api_client.post('/users', json=seed)

        reply = api_client.delete('/users/1')
        assert reply.status_code == 200
Test Discovery
# File naming conventions:
# test_*.py — pytest discovers these files
# *_test.py — also discovered
# Functions must start with test_
# Classes must start with Test
# Methods must start with test_
# Running tests:
# pytest — run all tests
# pytest test_calculator.py — run specific file
# pytest -v — verbose output
# pytest -k "test_add" — run tests matching pattern
# pytest -x — stop on first failure
# pytest --tb=short — shorter tracebacks
# pytest -m slow — run tests marked as slow
Test Coverage
Measure how much of your code is tested:
# Install: pip install pytest-cov
# Run: pytest --cov=your_module --cov-report=html
def calculate_grade(score):
    """Map a numeric score to a letter grade from 'A' down to 'F'."""
    # Cut-offs are checked from highest to lowest; the first one met wins.
    cutoffs = ((90, 'A'), (80, 'B'), (70, 'C'), (60, 'D'))
    for minimum, letter in cutoffs:
        if score >= minimum:
            return letter
    return 'F'
# test with coverage
def test_grade_a():
    """A score of 95 earns an 'A'."""
    grade = calculate_grade(95)
    assert grade == 'A'


def test_grade_b():
    """A score of 85 earns a 'B'."""
    grade = calculate_grade(85)
    assert grade == 'B'


def test_grade_c():
    """A score of 75 earns a 'C'."""
    grade = calculate_grade(75)
    assert grade == 'C'
# Missing tests for D and F grades — coverage report will show this
Test-Driven Development (TDD)
Write tests before implementation:
# Step 1: Write failing test
def test_fizzbuzz():
    """fizzbuzz handles plain numbers and multiples of 3, 5, and 15."""
    cases = {
        1: "1",
        2: "2",
        3: "Fizz",
        5: "Buzz",
        15: "FizzBuzz",
        7: "7",
    }
    for number, expected in cases.items():
        assert fizzbuzz(number) == expected
# Step 2: Run test — it fails
# Step 3: Write minimal code to pass
def fizzbuzz(n):
    """Return "Fizz", "Buzz", "FizzBuzz", or the number itself as a string."""
    word = ""
    if n % 3 == 0:
        word += "Fizz"
    if n % 5 == 0:
        word += "Buzz"
    # No divisor matched: fall back to the decimal representation.
    return word or str(n)
# Step 4: Run test — it passes
# Step 5: Refactor if needed, ensure tests still pass
TDD Cycle
# RED -> GREEN -> REFACTOR
# 1. RED: Write a failing test
def test_empty_string_returns_zero():
    """An empty string adds up to zero."""
    total = add("")
    assert total == 0
# 2. GREEN: Write minimal code to pass
def add(numbers):
    """Return 0 for the empty string, otherwise ``numbers`` parsed as an int."""
    return 0 if numbers == "" else int(numbers)
# 3. REFACTOR: Improve code while keeping tests green
def add(numbers):
    """Sum the comma-separated integers in ``numbers``; falsy input gives 0."""
    if numbers:
        return sum(map(int, numbers.split(",")))
    return 0
Real-World Examples
Example 1: Testing a REST API
import pytest
from flask import Flask, jsonify, request
# app.py
def create_app():
    """Build a Flask app exposing an in-memory ``/users`` resource."""
    app = Flask(__name__)
    store = {}  # user id -> user dict, private to this app instance

    @app.route('/users', methods=['GET'])
    def get_users():
        return jsonify([*store.values()])

    @app.route('/users/<int:user_id>', methods=['GET'])
    def get_user(user_id):
        found = store.get(user_id)
        if not found:
            return jsonify({"error": "Not found"}), 404
        return jsonify(found)

    @app.route('/users', methods=['POST'])
    def create_user():
        payload = request.get_json()
        # Ids are assigned sequentially from the current number of users.
        new_id = len(store) + 1
        record = {"id": new_id, **payload}
        store[new_id] = record
        return jsonify(record), 201

    return app
# test_app.py
@pytest.fixture
def client():
    """Yield a test client for a freshly created app in testing mode."""
    application = create_app()
    application.config['TESTING'] = True
    with application.test_client() as test_client:
        yield test_client
def test_get_users_empty(client):
    """Listing users on a fresh app returns an empty JSON array."""
    reply = client.get('/users')
    assert reply.status_code == 200
    assert reply.get_json() == []


def test_create_and_get_user(client):
    """A created user can be read back through its id."""
    # Create user
    created = client.post('/users', json={"name": "Alice"})
    assert created.status_code == 201
    user = created.get_json()
    assert user['name'] == 'Alice'

    # Get user
    fetched = client.get(f"/users/{user['id']}")
    assert fetched.status_code == 200
    assert fetched.get_json()['name'] == 'Alice'


def test_get_nonexistent_user(client):
    """An unknown id answers 404."""
    reply = client.get('/users/999')
    assert reply.status_code == 404
Example 2: Testing Database Code
import pytest
import sqlite3
class UserRepository:
    """CRUD helpers for the ``users`` table on a supplied connection."""

    def __init__(self, db):
        self.db = db

    def create(self, name, email):
        """Insert a user, commit, and return the new row's id."""
        insert_sql = "INSERT INTO users (name, email) VALUES (?, ?)"
        result = self.db.execute(insert_sql, (name, email))
        self.db.commit()
        return result.lastrowid

    def get_by_id(self, user_id):
        """Return the row with the given id, as produced by ``fetchone()``."""
        lookup = self.db.execute("SELECT * FROM users WHERE id = ?", (user_id,))
        return lookup.fetchone()

    def get_all(self):
        """Return every row of the ``users`` table."""
        rows = self.db.execute("SELECT * FROM users")
        return rows.fetchall()

    def delete(self, user_id):
        """Delete the row with the given id and commit."""
        self.db.execute("DELETE FROM users WHERE id = ?", (user_id,))
        self.db.commit()
@pytest.fixture
def db():
    """Yield an in-memory SQLite connection with an empty ``users`` table.

    Rows are returned as ``sqlite3.Row`` objects, so columns can be read by
    name (``row['name']``) as well as by position. The connection is closed
    after the test that used it.
    """
    conn = sqlite3.connect(':memory:')
    # Without a row factory sqlite3 returns plain tuples, and name-based
    # lookups such as ``user['name']`` raise TypeError.
    conn.row_factory = sqlite3.Row
    conn.execute('''
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            name TEXT,
            email TEXT UNIQUE
        )
    ''')
    yield conn
    conn.close()
@pytest.fixture
def repo(db):
    """Return a UserRepository wrapping the ``db`` fixture's connection."""
    repository = UserRepository(db)
    return repository
def test_create_user(repo):
    """Creating a user yields a row id."""
    new_id = repo.create("Alice", "alice@test.com")
    assert new_id is not None


def test_get_user(repo):
    """A created user can be fetched and its columns read by name."""
    new_id = repo.create("Bob", "bob@test.com")
    row = repo.get_by_id(new_id)
    assert row['name'] == 'Bob'
    assert row['email'] == 'bob@test.com'


def test_get_all_users(repo):
    """Two created users are both listed."""
    for name, email in (("Alice", "a@test.com"), ("Bob", "b@test.com")):
        repo.create(name, email)
    assert len(repo.get_all()) == 2


def test_delete_user(repo):
    """A deleted user can no longer be fetched."""
    new_id = repo.create("Charlie", "c@test.com")
    repo.delete(new_id)
    assert repo.get_by_id(new_id) is None
Common Mistakes
| Mistake | Problem | Solution |
|---|---|---|
| Tests that depend on order | Fragile tests | Each test should be independent |
| Not cleaning up state | Tests affect each other | Use fixtures for setup/teardown |
| Testing implementation details | Brittle tests | Test behavior, not implementation |
| Not using fixtures | Duplicated setup code | Extract common setup to fixtures |
| Ignoring edge cases | Bugs in production | Test boundary conditions |
| Testing too much at once | Hard to diagnose failures | One assertion per test concept |
Best Practices
# 1. Follow Arrange-Act-Assert pattern
def test_add():
    """Arrange-Act-Assert laid out as three explicit steps."""
    # Arrange
    left, right = 2, 3
    # Act
    total = add(left, right)
    # Assert
    assert total == 5


# 2. Use descriptive test names
def test_add_returns_sum_of_two_positive_integers():
    """The test name alone states the behavior under test."""
    total = add(2, 3)
    assert total == 5


# 3. Test edge cases
def test_add_with_zeros():
    """Zero on either side, or on both, is handled."""
    cases = [((0, 0), 0), ((0, 5), 5), ((5, 0), 5)]
    for (left, right), expected in cases:
        assert add(left, right) == expected


# 4. Use pytest.raises for exception testing
def test_divide_by_zero():
    """A zero divisor raises ValueError."""
    with pytest.raises(ValueError):
        divide(10, 0)


# 5. Keep tests fast — mock external dependencies
@patch('requests.get')
def test_api_call(mock_get):
    """fetch_data returns the JSON served by the mocked ``requests.get``."""
    payload = {"data": "value"}
    mock_get.return_value.json.return_value = payload
    fetched = fetch_data()
    assert fetched == {"data": "value"}
Key Takeaways
- Name test files `test_*.py` and functions `test_*` for pytest discovery
- Use fixtures for reusable setup/teardown — they're injected automatically by name
- `@pytest.mark.parametrize` enables data-driven tests with multiple inputs
- Mock external dependencies (APIs, databases) to isolate units under test
- Follow Arrange-Act-Assert pattern for clear, readable tests
- Use `pytest.raises()` for exception testing
- Aim for high coverage but prioritize testing critical paths and edge cases