mock

Python mock 模块简介与单元测试示例

Python 的 mock 模块简介与单元测试使用指南

简介

unittest.mock(简称 mock)是 Python 标准库中用于替换、监视和断言依赖行为的工具。常用于单元测试中隔离外部依赖(网络、数据库、时间、IO 等),使测试只关注被测代码逻辑。

核心思想:用假对象替换真实依赖,控制其行为,验证交互是否符合预期。

常用 API

  • Mock / MagicMock:创建可配置的假对象
    • Mock:普通模拟对象
    • MagicMock:支持特殊方法(如 __enter____iter__),适合模拟上下文管理器、迭代器等
  • patch / patch.object:临时替换对象
    • patch:替换模块级别的对象(函数、类、变量)
    • patch.object:替换对象的属性或方法
    • 可作为装饰器或上下文管理器使用
  • return_value:设置调用返回值(单一值)
  • side_effect:设置调用副作用
    • 可以是异常(模拟错误)
    • 可以是列表(多次调用返回不同值)
    • 可以是函数(动态计算返回值)
  • 断言方法:验证调用情况
    • assert_called():至少被调用一次
    • assert_called_once():恰好被调用一次
    • assert_called_with(*args, **kwargs):最后一次调用的参数
    • assert_called_once_with(*args, **kwargs):只调用一次且参数匹配
    • assert_not_called():从未被调用
  • create_autospec / autospec=True:创建带签名约束的 mock,防止错误调用
  • AsyncMock:用于异步函数的 mock(Python 3.8+)

关键概念:patch 路径的正确选择

⚠️ 常见错误:新手经常搞混 patch 的路径,导致 mock 不生效。

规则:patch 要替换的是"被测代码使用时的路径",而非"原始定义的路径"

# 文件结构
# myapp/email.py
def send_email(to, subject, body):
    print(f"发送邮件到 {to}")

# myapp/service.py
from myapp.email import send_email  # 这里导入了 send_email

def notify_user(user):
    send_email(user.email, "Hi", "Welcome")  # 这里使用的是 service 模块的 send_email
    return True

# tests/test_service.py
from unittest.mock import patch
from myapp import service

# ✅ 正确:patch 被测模块中的引用
@patch('myapp.service.send_email')  
def test_correct(mock_send):
    # 这会成功 mock
    pass

# ❌ 错误:patch 原始定义位置
@patch('myapp.email.send_email')  
def test_wrong(mock_send):
    # 这不会生效,因为 service.py 中的 send_email 不是从 email 模块调用的
    pass
记忆技巧:"在哪用就 patch 哪",即 patch('使用者模块.被用对象')

示例 1:基本 patch 装饰器用法

# myapp/email.py
def send_email(to, subject, body):
    # 实际发送邮件的实现(可能需要网络、SMTP 配置等)
    import smtplib
    # ...真实的邮件发送逻辑...
    pass

# myapp/service.py
from myapp.email import send_email

def notify_user(user):
    """通知用户,发送欢迎邮件"""
    if not user.email:
        return False
    send_email(user.email, "Hi", "Welcome")
    return True
# tests/test_service.py
import unittest
from unittest.mock import patch
from myapp import service

class User: 
    def __init__(self, email): 
        self.email = email

class ServiceTest(unittest.TestCase):
    
    @patch('myapp.service.send_email')  # 替换 service 模块中的 send_email 引用
    def test_notify_user_sends_email(self, mock_send):
        """测试:用户有邮箱时应该发送邮件"""
        # 准备测试数据
        user = User('a@example.com')
        
        # 执行被测方法
        result = service.notify_user(user)
        
        # 验证结果
        self.assertTrue(result)  # 返回 True
        
        # 验证交互:send_email 被正确调用
        mock_send.assert_called_once_with('a@example.com', 'Hi', 'Welcome')
    
    @patch('myapp.service.send_email')
    def test_notify_user_without_email(self, mock_send):
        """测试:用户没有邮箱时不发送邮件"""
        user = User(None)
        
        result = service.notify_user(user)
        
        self.assertFalse(result)  # 返回 False
        mock_send.assert_not_called()  # send_email 未被调用

if __name__ == '__main__':
    unittest.main()

示例 2:patch.object 与 return_value

from unittest.mock import patch, Mock
import unittest

class Database:
    def connect(self): 
        """连接数据库"""
        pass
    
    def query(self, q): 
        """执行查询"""
        pass

def get_user_count(db: Database):
    """获取用户总数"""
    db.connect()
    return db.query('select count(*) from users')

class DBTest(unittest.TestCase):
    
    def test_get_user_count_with_mock(self):
        """使用 Mock 对象测试数据库查询"""
        # 创建符合 Database 接口的假对象
        fake_db = Mock(spec=Database)
        
        # 设置 query 方法的返回值
        fake_db.query.return_value = 42
        
        # 调用被测函数
        result = get_user_count(fake_db)
        
        # 验证
        self.assertEqual(result, 42)
        fake_db.connect.assert_called_once()  # connect 被调用
        fake_db.query.assert_called_once_with('select count(*) from users')
    
    @patch.object(Database, 'query', return_value=100)
    @patch.object(Database, 'connect')
    def test_with_patch_object(self, mock_connect, mock_query):
        """使用 patch.object 测试真实 Database 类"""
        db = Database()
        result = get_user_count(db)
        
        self.assertEqual(result, 100)
        mock_connect.assert_called_once()
        mock_query.assert_called_once()

示例 3:side_effect 模拟异常与多次返回

from unittest.mock import Mock

# 1. 模拟抛出异常
m = Mock()
m.side_effect = Exception('数据库连接失败')
# m() 会抛出异常

# 2. 多次调用返回不同值
m = Mock()
m.side_effect = [1, 2, 3]
print(m())  # 1
print(m())  # 2
print(m())  # 3
# 第 4 次调用会抛出 StopIteration

# 3. 使用函数动态计算
def dynamic_return(x):
    return x * 2

m = Mock()
m.side_effect = dynamic_return
print(m(5))   # 10
print(m(10))  # 20

# 4. 实际测试示例
import unittest
from unittest.mock import patch

def fetch_data(api_client):
    """尝试获取数据,失败时重试"""
    try:
        return api_client.get('/data')
    except Exception:
        return api_client.get('/data')  # 重试一次

class FetchTest(unittest.TestCase):
    
    def test_retry_on_failure(self):
        """测试:第一次失败,第二次成功"""
        mock_client = Mock()
        # 第一次调用抛异常,第二次返回数据
        mock_client.get.side_effect = [Exception('timeout'), {'result': 'ok'}]
        
        result = fetch_data(mock_client)
        
        self.assertEqual(result, {'result': 'ok'})
        self.assertEqual(mock_client.get.call_count, 2)  # 调用了 2 次

示例 4:异步函数(AsyncMock)

import asyncio
from unittest.mock import AsyncMock, patch
import unittest

async def fetch(url):
    """异步网络请求(需要真实网络)"""
    import aiohttp
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return await resp.json()

async def get_title(url):
    """获取页面标题"""
    data = await fetch(url)
    return data['title']

class AsyncTest(unittest.IsolatedAsyncioTestCase):
    
    @patch('__main__.fetch', new_callable=AsyncMock)
    async def test_get_title(self, mock_fetch):
        """测试异步函数"""
        # 设置 async 函数的返回值
        mock_fetch.return_value = {'title': 'Python Mock 教程'}
        
        # 调用异步函数
        title = await get_title('http://example.com')
        
        # 验证
        self.assertEqual(title, 'Python Mock 教程')
        mock_fetch.assert_called_once_with('http://example.com')
    
    @patch('__main__.fetch', new_callable=AsyncMock)
    async def test_fetch_error(self, mock_fetch):
        """测试异步函数抛异常"""
        mock_fetch.side_effect = Exception('网络错误')
        
        with self.assertRaises(Exception):
            await get_title('http://example.com')

示例 5:上下文管理器与 with 语句

from unittest.mock import patch, mock_open

def read_config(filename):
    """读取配置文件"""
    with open(filename) as f:
        return f.read()

def test_read_config():
    """测试文件读取(不实际创建文件)"""
    fake_content = "api_key=12345\napi_secret=abcde"
    
    # mock_open 模拟文件操作
    with patch('builtins.open', mock_open(read_data=fake_content)):
        content = read_config('config.txt')
        assert content == fake_content

# 也可以作为装饰器
@patch('builtins.open', mock_open(read_data='test'))
def test_as_decorator():
    content = read_config('any.txt')
    assert content == 'test'

常见错误与解决方案

错误 1:patch 路径错误

# ❌ 错误
@patch('os.path.exists')  # 如果被测代码是 from os.path import exists
def test_wrong(mock_exists):
    pass

# ✅ 正确
@patch('mymodule.exists')  # patch 被测模块中的引用
def test_correct(mock_exists):
    pass

错误 2:装饰器顺序错误

# ❌ 错误:参数顺序与装饰器相反
@patch('module.func_a')
@patch('module.func_b')
def test(mock_a, mock_b):  # 应该是 mock_b, mock_a
    pass

# ✅ 正确:从下到上对应参数
@patch('module.func_a')
@patch('module.func_b')
def test(mock_b, mock_a):  # 最下面的装饰器对应第一个参数
    pass

错误 3:忘记使用 autospec

# ❌ 危险:可以调用不存在的方法
mock = Mock()
mock.some_nonexistent_method()  # 不会报错!

# ✅ 安全:限制为真实接口
from mymodule import RealClass
mock = Mock(spec=RealClass)
# mock.some_nonexistent_method()  # 会抛出 AttributeError

实用建议

  • patch 路径选择:按被测代码的导入路径替换(patch('使用模块.对象')),而非原实现所在模块
  • 优先使用 autospec:使用 specautospec=Truecreate_autospec 防止调用签名错误
  • 只 mock 边界:只 mock 外部依赖(IO、网络、数据库、时间),不要过度 mock 内部逻辑
  • 明确验证交互:使用 assert_called_* 系列方法验证期望的调用行为
  • 测试异常场景:使用 side_effect 模拟错误情况,确保代码有容错处理
  • 保持测试简单:一个测试只验证一个行为,避免过于复杂的 mock 设置

运行测试

# 方式 1:使用 unittest
python -m unittest tests/test_service.py

# 方式 2:使用 pytest(推荐,输出更友好)
pip install pytest
pytest tests/

# 方式 3:单个测试文件
python -m unittest tests.test_service.ServiceTest.test_notify_user_sends_email

# 查看详细输出
python -m unittest -v tests/
pytest -v tests/

完整实战示例

# myapp/weather.py
import requests

def get_weather(city):
    """获取城市天气(需要网络请求)"""
    resp = requests.get(f'https://api.weather.com/{city}')
    return resp.json()['temperature']

def should_bring_umbrella(city):
    """判断是否需要带伞"""
    temp = get_weather(city)
    return temp < 20  # 低于 20 度建议带伞

# tests/test_weather.py
import unittest
from unittest.mock import patch
from myapp.weather import should_bring_umbrella

class WeatherTest(unittest.TestCase):
    
    @patch('myapp.weather.get_weather')
    def test_bring_umbrella_when_cold(self, mock_get_weather):
        """测试:温度低时建议带伞"""
        mock_get_weather.return_value = 15
        
        result = should_bring_umbrella('Beijing')
        
        self.assertTrue(result)
        mock_get_weather.assert_called_once_with('Beijing')
    
    @patch('myapp.weather.get_weather')
    def test_no_umbrella_when_warm(self, mock_get_weather):
        """测试:温度高时不需要带伞"""
        mock_get_weather.return_value = 25
        
        result = should_bring_umbrella('Shanghai')
        
        self.assertFalse(result)

总结:mock 的核心是"替换"和"验证"。通过控制依赖的行为,让测试专注于业务逻辑本身。掌握 patch 路径、return_value、side_effect 和断言方法,就能应对大部分单元测试场景。

💡 下一步:结合 pytest 的 fixture 和参数化测试,可以让测试代码更简洁优雅。

“您的支持是我持续分享的动力”

微信收款码
微信
支付宝收款码
支付宝

目录