Python 的 mock 模块简介与单元测试使用指南※
简介※
unittest.mock(简称 mock)是 Python 标准库中用于替换、监视和断言依赖行为的工具。常用于单元测试中隔离外部依赖(网络、数据库、时间、IO 等),使测试只关注被测代码逻辑。
核心思想:用假对象替换真实依赖,控制其行为,验证交互是否符合预期。
常用 API※
Mock/MagicMock:创建可配置的假对象- Mock:普通模拟对象
- MagicMock:支持特殊方法(如
__enter__、__iter__),适合模拟上下文管理器、迭代器等
patch/patch.object:临时替换对象- patch:替换模块级别的对象(函数、类、变量)
- patch.object:替换对象的属性或方法
- 可作为装饰器或上下文管理器使用
return_value:设置调用返回值(单一值)side_effect:设置调用副作用- 可以是异常(模拟错误)
- 可以是列表(多次调用返回不同值)
- 可以是函数(动态计算返回值)
- 断言方法:验证调用情况
assert_called():至少被调用一次assert_called_once():恰好被调用一次assert_called_with(*args, **kwargs):最后一次调用的参数assert_called_once_with(*args, **kwargs):只调用一次且参数匹配assert_not_called():从未被调用
create_autospec/autospec=True:创建带签名约束的 mock,防止错误调用AsyncMock:用于异步函数的 mock(Python 3.8+)
关键概念:patch 路径的正确选择※
⚠️ 常见错误:新手经常搞混 patch 的路径,导致 mock 不生效。
规则:patch 要替换的是"被测代码使用时的路径",而非"原始定义的路径"
# 文件结构
# myapp/email.py
def send_email(to, subject, body):
print(f"发送邮件到 {to}")
# myapp/service.py
from myapp.email import send_email # 这里导入了 send_email
def notify_user(user):
send_email(user.email, "Hi", "Welcome") # 这里使用的是 service 模块的 send_email
return True
# tests/test_service.py
from unittest.mock import patch
from myapp import service
# ✅ 正确:patch 被测模块中的引用
@patch('myapp.service.send_email')
def test_correct(mock_send):
# 这会成功 mock
pass
# ❌ 错误:patch 原始定义位置
@patch('myapp.email.send_email')
def test_wrong(mock_send):
# 这不会生效,因为 service.py 中的 send_email 不是从 email 模块调用的
pass
记忆技巧:"在哪用就 patch 哪",即 patch('使用者模块.被用对象')
示例 1:基本 patch 装饰器用法※
# myapp/email.py
def send_email(to, subject, body):
# 实际发送邮件的实现(可能需要网络、SMTP 配置等)
import smtplib
# ...真实的邮件发送逻辑...
pass
# myapp/service.py
from myapp.email import send_email
def notify_user(user):
"""通知用户,发送欢迎邮件"""
if not user.email:
return False
send_email(user.email, "Hi", "Welcome")
return True
# tests/test_service.py
import unittest
from unittest.mock import patch
from myapp import service
class User:
def __init__(self, email):
self.email = email
class ServiceTest(unittest.TestCase):
@patch('myapp.service.send_email') # 替换 service 模块中的 send_email 引用
def test_notify_user_sends_email(self, mock_send):
"""测试:用户有邮箱时应该发送邮件"""
# 准备测试数据
user = User('a@example.com')
# 执行被测方法
result = service.notify_user(user)
# 验证结果
self.assertTrue(result) # 返回 True
# 验证交互:send_email 被正确调用
mock_send.assert_called_once_with('a@example.com', 'Hi', 'Welcome')
@patch('myapp.service.send_email')
def test_notify_user_without_email(self, mock_send):
"""测试:用户没有邮箱时不发送邮件"""
user = User(None)
result = service.notify_user(user)
self.assertFalse(result) # 返回 False
mock_send.assert_not_called() # send_email 未被调用
if __name__ == '__main__':
unittest.main()
示例 2:patch.object 与 return_value※
from unittest.mock import patch, Mock
import unittest
class Database:
def connect(self):
"""连接数据库"""
pass
def query(self, q):
"""执行查询"""
pass
def get_user_count(db: Database):
"""获取用户总数"""
db.connect()
return db.query('select count(*) from users')
class DBTest(unittest.TestCase):
def test_get_user_count_with_mock(self):
"""使用 Mock 对象测试数据库查询"""
# 创建符合 Database 接口的假对象
fake_db = Mock(spec=Database)
# 设置 query 方法的返回值
fake_db.query.return_value = 42
# 调用被测函数
result = get_user_count(fake_db)
# 验证
self.assertEqual(result, 42)
fake_db.connect.assert_called_once() # connect 被调用
fake_db.query.assert_called_once_with('select count(*) from users')
@patch.object(Database, 'query', return_value=100)
@patch.object(Database, 'connect')
def test_with_patch_object(self, mock_connect, mock_query):
"""使用 patch.object 测试真实 Database 类"""
db = Database()
result = get_user_count(db)
self.assertEqual(result, 100)
mock_connect.assert_called_once()
mock_query.assert_called_once()
示例 3:side_effect 模拟异常与多次返回※
from unittest.mock import Mock
# 1. 模拟抛出异常
m = Mock()
m.side_effect = Exception('数据库连接失败')
# m() 会抛出异常
# 2. 多次调用返回不同值
m = Mock()
m.side_effect = [1, 2, 3]
print(m()) # 1
print(m()) # 2
print(m()) # 3
# 第 4 次调用会抛出 StopIteration
# 3. 使用函数动态计算
def dynamic_return(x):
return x * 2
m = Mock()
m.side_effect = dynamic_return
print(m(5)) # 10
print(m(10)) # 20
# 4. 实际测试示例
import unittest
from unittest.mock import patch
def fetch_data(api_client):
"""尝试获取数据,失败时重试"""
try:
return api_client.get('/data')
except Exception:
return api_client.get('/data') # 重试一次
class FetchTest(unittest.TestCase):
def test_retry_on_failure(self):
"""测试:第一次失败,第二次成功"""
mock_client = Mock()
# 第一次调用抛异常,第二次返回数据
mock_client.get.side_effect = [Exception('timeout'), {'result': 'ok'}]
result = fetch_data(mock_client)
self.assertEqual(result, {'result': 'ok'})
self.assertEqual(mock_client.get.call_count, 2) # 调用了 2 次
示例 4:异步函数(AsyncMock)※
import asyncio
from unittest.mock import AsyncMock, patch
import unittest
async def fetch(url):
"""异步网络请求(需要真实网络)"""
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return await resp.json()
async def get_title(url):
"""获取页面标题"""
data = await fetch(url)
return data['title']
class AsyncTest(unittest.IsolatedAsyncioTestCase):
@patch('__main__.fetch', new_callable=AsyncMock)
async def test_get_title(self, mock_fetch):
"""测试异步函数"""
# 设置 async 函数的返回值
mock_fetch.return_value = {'title': 'Python Mock 教程'}
# 调用异步函数
title = await get_title('http://example.com')
# 验证
self.assertEqual(title, 'Python Mock 教程')
mock_fetch.assert_called_once_with('http://example.com')
@patch('__main__.fetch', new_callable=AsyncMock)
async def test_fetch_error(self, mock_fetch):
"""测试异步函数抛异常"""
mock_fetch.side_effect = Exception('网络错误')
with self.assertRaises(Exception):
await get_title('http://example.com')
示例 5:上下文管理器与 with 语句※
from unittest.mock import patch, mock_open
def read_config(filename):
"""读取配置文件"""
with open(filename) as f:
return f.read()
def test_read_config():
"""测试文件读取(不实际创建文件)"""
fake_content = "api_key=12345\napi_secret=abcde"
# mock_open 模拟文件操作
with patch('builtins.open', mock_open(read_data=fake_content)):
content = read_config('config.txt')
assert content == fake_content
# 也可以作为装饰器
@patch('builtins.open', mock_open(read_data='test'))
def test_as_decorator():
content = read_config('any.txt')
assert content == 'test'
常见错误与解决方案※
错误 1:patch 路径错误※
# ❌ 错误
@patch('os.path.exists') # 如果被测代码是 from os.path import exists
def test_wrong(mock_exists):
pass
# ✅ 正确
@patch('mymodule.exists') # patch 被测模块中的引用
def test_correct(mock_exists):
pass
错误 2:装饰器顺序错误※
# ❌ 错误:参数顺序与装饰器相反
@patch('module.func_a')
@patch('module.func_b')
def test(mock_a, mock_b): # 应该是 mock_b, mock_a
pass
# ✅ 正确:从下到上对应参数
@patch('module.func_a')
@patch('module.func_b')
def test(mock_b, mock_a): # 最下面的装饰器对应第一个参数
pass
错误 3:忘记使用 autospec※
# ❌ 危险:可以调用不存在的方法
mock = Mock()
mock.some_nonexistent_method() # 不会报错!
# ✅ 安全:限制为真实接口
from mymodule import RealClass
mock = Mock(spec=RealClass)
# mock.some_nonexistent_method() # 会抛出 AttributeError
实用建议※
- patch 路径选择:按被测代码的导入路径替换(
patch('使用模块.对象')),而非原实现所在模块 - 优先使用 autospec:使用
spec、autospec=True或create_autospec防止调用签名错误 - 只 mock 边界:只 mock 外部依赖(IO、网络、数据库、时间),不要过度 mock 内部逻辑
- 明确验证交互:使用
assert_called_*系列方法验证期望的调用行为 - 测试异常场景:使用
side_effect模拟错误情况,确保代码有容错处理 - 保持测试简单:一个测试只验证一个行为,避免过于复杂的 mock 设置
运行测试※
# 方式 1:使用 unittest
python -m unittest tests/test_service.py
# 方式 2:使用 pytest(推荐,输出更友好)
pip install pytest
pytest tests/
# 方式 3:单个测试文件
python -m unittest tests.test_service.ServiceTest.test_notify_user_sends_email
# 查看详细输出
python -m unittest -v tests/
pytest -v tests/
完整实战示例※
# myapp/weather.py
import requests
def get_weather(city):
"""获取城市天气(需要网络请求)"""
resp = requests.get(f'https://api.weather.com/{city}')
return resp.json()['temperature']
def should_bring_umbrella(city):
"""判断是否需要带伞"""
temp = get_weather(city)
return temp < 20 # 低于 20 度建议带伞
# tests/test_weather.py
import unittest
from unittest.mock import patch
from myapp.weather import should_bring_umbrella
class WeatherTest(unittest.TestCase):
@patch('myapp.weather.get_weather')
def test_bring_umbrella_when_cold(self, mock_get_weather):
"""测试:温度低时建议带伞"""
mock_get_weather.return_value = 15
result = should_bring_umbrella('Beijing')
self.assertTrue(result)
mock_get_weather.assert_called_once_with('Beijing')
@patch('myapp.weather.get_weather')
def test_no_umbrella_when_warm(self, mock_get_weather):
"""测试:温度高时不需要带伞"""
mock_get_weather.return_value = 25
result = should_bring_umbrella('Shanghai')
self.assertFalse(result)