微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在python单元测试中的导入模块中模拟全局变量?

如何解决如何在python单元测试中的导入模块中模拟全局变量?

我有一个特殊的情况,如下面的帖子!

Post

但是区别是我的数据库模块还具有一个需要模拟的全局变量

    """Initialzation"""
from database.db_init import DB
from common.common_util import CommonUtil as util

db_username = util.get_value_from_ssm_parameter_store(
    '/_username')
db_password = util.get_value_from_ssm_parameter_store(
    '/rds/_password')
db_host = DB.get_rds_host()


def create_session():
    """this will create a db session"""
    db = DB(user=db_username,password=db_password,host=db_host,database='test')
    Session = db.getSession()
    session = Session()
    return session

我模拟了get_value_from_ssm_parameter_store()函数,甚至将变量模拟为@patch('database.db_username','test'),还尝试了unittest类中的database.db_username = Magicmock(return_value ='test') !

但是电话仍然在打给aws。有人可以帮助我模拟导入模块中的全局变量吗?

解决方法

由于 util.get_value_from_ssm_parameter_store()DB.get_rds_host() 方法在 session.py 的模块范围内执行。

您应该在从 create_session 模块导入 session 函数之前修补这些方法。

例如

common/common_util.py

class CommonUtil:
    @staticmethod
    def get_value_from_ssm_parameter_store(key):
        print('call real aws')

database/db_init.py

class Session:
    pass


class DB:
    @staticmethod
    def get_rds_host():
        return '127.0.0.1'

    def __init__(self,user,password,host,database) -> None:
        pass

    def getSession(self):
        return Session

session.py

from database.db_init import DB
from common.common_util import CommonUtil as util


db_username = util.get_value_from_ssm_parameter_store(
    '/_username')
db_password = util.get_value_from_ssm_parameter_store(
    '/rds/_password')
db_host = DB.get_rds_host()


def create_session():
    """this will create a db session"""
    db = DB(user=db_username,password=db_password,host=db_host,database='test')
    Session = db.getSession()
    session = Session()
    return session

test_session.py

import unittest
from unittest.mock import patch,Mock,call
from common.common_util import CommonUtil as util
from database.db_init import DB


def get_value_from_ssm_parameter_store_side_effect(key):
    if key == '/_username':
        return 'teresa teng'
    if key == '/rds/_password':
        return '123456'


original_get_value_from_ssm_parameter_store = util.get_value_from_ssm_parameter_store
original_get_rds_host = DB.get_rds_host


util.get_value_from_ssm_parameter_store = Mock(side_effect=get_value_from_ssm_parameter_store_side_effect)
DB.get_rds_host = Mock(return_value='192.168.1.1')


class TestSession(unittest.TestCase):
    @patch('session.DB',autospec=True)
    def test_create_session(self,mock_DB):
        from session import create_session
        db_instance = mock_DB.return_value
        mock_session = Mock()
        db_instance.getSession.return_value = mock_session
        create_session()
        mock_DB.assert_called_once_with('teresa teng','123456','192.168.1.1','test')
        util.get_value_from_ssm_parameter_store.assert_has_calls([call('/_username'),call('/rds/_password')])
        db_instance.getSession.assert_called_once()

        # restore mock
        util.get_value_from_ssm_parameter_store = original_get_value_from_ssm_parameter_store
        DB.get_rds_host = original_get_rds_host


if __name__ == '__main__':
    unittest.main()

单元测试结果:

 ⚡  coverage run /Users/dulin/workspace/github.com/mrdulin/python-codelab/src/stackoverflow/64329623/test_session.py && coverage report -m --include='./src/**'
.
----------------------------------------------------------------------
Ran 1 test in 0.011s

OK
Name                                               Stmts   Miss  Cover   Missing
--------------------------------------------------------------------------------
src/stackoverflow/64329623/common/common_util.py       4      1    75%   4
src/stackoverflow/64329623/database/db_init.py        10      3    70%   8,11,14
src/stackoverflow/64329623/session.py                 10      0   100%
src/stackoverflow/64329623/test_session.py            28      0   100%
--------------------------------------------------------------------------------
TOTAL                                                 52      4    92%

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?