如何解决如何在 Python 中使用默认参数重载函数
from typing import Iterable,List,Optional,overload,Literal,Union,Tuple,Any
import sqlite3
@overload
def query_db(
query: str,params: Optional[Iterable],as_tuple: Literal[False]
) -> List[sqlite3.Row]:
...
@overload
def query_db(
query: str,as_tuple: Literal[True]
) -> List[Tuple[Any,...]]:
...
def query_db(
query: str,params: Optional[Iterable] = None,as_tuple: bool = False
) -> Union[List[sqlite3.Row],List[Tuple[Any,...]]]:
"""Run a query against the given db.
If params is not None,securely construct a query from the given
query string and params.
"""
with sqlite3.connect("/dummy.sqlite") as con:
if not as_tuple:
con.row_factory = sqlite3.Row
if params is None:
rows = con.execute(query).fetchall()
else:
rows = con.execute(query,params).fetchall()
return rows
a = query_db("SELECT test_column FROM test_table")
a[0]["test_column"]
我不知道如何进行类型检查。
如果我不添加重载,mypy 会抱怨我可能会使用 str
索引对元组进行索引。
as_tuple
参数默认为 false,因此当不向函数提供第二个和第三个参数时,mypy 应该能够推断出我正在使用第一个重载(因为实际实现具有默认参数) .
然而,实际发生的是 mypy 抱怨提供的重载都不匹配,因为它认为我还需要提供最后两个参数。
当我只是将默认参数复制粘贴到每个重载时,mypy 会抱怨我无法将 False
分配给 as_tuple: Literal[True]
。
有没有办法让它在运行时检查它的工作方式? 我真的不想修改实际签名,因为该函数在我们的测试中被广泛使用。
解决方法
好的,我找到了一个 open issue for this on mypy。
目前的解决方案显然是注释所有可能的显式参数组合,在我的情况下导致:
@overload
def query_db(
query: str,params: Optional[Iterable],as_tuple: Literal[False]
) -> List[sqlite3.Row]:
...
@overload
def query_db(
query: str,as_tuple: Literal[True]
) -> List[Tuple[Any,...]]:
...
@overload
def query_db(
query: str,params: Optional[Iterable]
) -> List[sqlite3.Row]:
...
@overload
def query_db(
query: str,*,...]]:
...
@overload
def query_db(
query: str
) -> List[sqlite3.Row]:
...
def query_db(
query: str,params: Optional[Iterable] = None,as_tuple: bool = False
) -> Union[List[sqlite3.Row],List[Tuple[Any,...]]]:
...
,
如果您让某些重载中的参数采用默认值,那么您就不需要那么多重载。当您将布尔值传递给 as_tuple
时,您可能还需要额外的重载:
from typing import Iterable,List,Optional,overload,Literal,Union,Tuple,Any
import sqlite3
@overload
def query_db(
query: str,params: Optional[Iterable]=...,as_tuple: Literal[False]=...
) -> List[sqlite3.Row]:
...
@overload
def query_db(
query: str,...]]:
...
@overload
def query_db(
query: str,...]]:
...
@overload
def query_db(
query: str,as_tuple: bool=...
) -> Union[List[sqlite3.Row],...]]]:
...
def query_db(
query: str,...]]]:
"""Run a query against the given db.
If params is not None,securely construct a query from the given
query string and params.
"""
with sqlite3.connect("/dummy.sqlite") as con:
if not as_tuple:
con.row_factory = sqlite3.Row
if params is None:
rows = con.execute(query).fetchall()
else:
rows = con.execute(query,params).fetchall()
return rows
query: str
params: Optional[Iterable]
as_tuple: bool
reveal_type(query_db(query,params,as_tuple=True))
reveal_type(query_db(query,params))
reveal_type(query_db(query))
reveal_type(query_db(query,as_tuple=False))
reveal_type(query_db(query,as_tuple=as_tuple))
reveal_type(query_db(query,as_tuple=as_tuple))
运行这个给出:
main.py:51: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:52: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:53: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:54: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:55: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:56: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:57: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row],builtins.list[builtins.tuple[Any]]]'
main.py:58: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row],builtins.list[builtins.tuple[Any]]]'
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。