From 39191514cc0d62039dccdc38cff90a5bac383a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Wed, 15 Nov 2023 12:21:51 +0100 Subject: [PATCH] mypy: enable warn_return_any --- pyproject.toml | 2 +- src/capport/api/proxy_fix.py | 5 ++++- src/capport/utils/nft_socket.py | 24 ++++++++++++++++++++---- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b70394..abd45c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ python_version = "3.11" # disallow_any_generics = true # disallow_untyped_defs = true # warn_redundant_casts = true -# warn_return_any = true +warn_return_any = true warn_unused_configs = true # warn_unused_ignores = true # warn_unreachable = true diff --git a/src/capport/api/proxy_fix.py b/src/capport/api/proxy_fix.py index e6121a3..6162e82 100644 --- a/src/capport/api/proxy_fix.py +++ b/src/capport/api/proxy_fix.py @@ -67,7 +67,10 @@ def local_proxy_fix(request: quart.Request): class LocalProxyFixRequestHandler: - def __init__(self, orig_handle_request): + def __init__( + self, + orig_handle_request: typing.Callable[[quart.Request], typing.Awaitable[quart.Response | werkzeug.Response]], + ): self._orig_handle_request = orig_handle_request async def __call__(self, request: quart.Request) -> quart.Response | werkzeug.Response: diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py index 70884db..41e1b66 100644 --- a/src/capport/utils/nft_socket.py +++ b/src/capport/utils/nft_socket.py @@ -50,7 +50,7 @@ _monkey_patch_pyroute2() def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: - msg = msg_class() + msg: _NlMsgBase = msg_class() for key, value in header.items(): msg["header"][key] = value for key, value in fields.items(): @@ -102,16 +102,32 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): with self.begin() as tx: tx.put(msg_type, msg_flags, attrs=attrs, **fields) - def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: + def nft_dump( + self, + msg_type: int, + msg_flags: int = 0, + /, + *, + attrs: dict = {}, + **fields, + ) -> typing.Iterator[nfgen_msg]: msg_flags |= pyroute2.netlink.NLM_F_DUMP return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) - def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: + def nft_get( + self, + msg_type: int, + msg_flags: int = 0, + /, + *, + attrs: dict = {}, + **fields, + ) -> typing.Iterator[nfgen_msg]: msg_class: type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_flags |= pyroute2.netlink.NLM_F_REQUEST msg = _build(msg_class, attrs=attrs, **fields) - return self.nlm_request(msg, msg_type, msg_flags) + return self.nlm_request(msg, msg_type, msg_flags) # type: ignore class NFTTransaction: