diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..5cb7990 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ] +# E402 module level import not at top of file [ usually on purpose. might use individual overrides instead? ] +# E701 multiple statements on one line [ still quite readable in short forms ] +# E713 test for membership should be ‘not in’ [ disagree: want `not a in x` ] +# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ] +# W503 line break before binary operator [ gotta pick one way ] +extend-ignore = E266,E402,E701,E713,E714,W503 +max-line-length = 120 +exclude = *_pb2.py +application-import-names = capport diff --git a/.pycodestyle b/.pycodestyle deleted file mode 100644 index 8f3dbe0..0000000 --- a/.pycodestyle +++ /dev/null @@ -1,11 +0,0 @@ -[pycodestyle] -# E241 multiple spaces after ':' [ want to align stuff ] -# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ] -# E501 line too long [ temporary? can't disable it in certain places.. ] -# E701 multiple statements on one line (colon) [ perfectly readable ] -# E713 test for membership should be ‘not in’ [ disagree: want `not a in x` ] -# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ] -# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ] -ignore = E241,E266,E501,E701,E713,E714,W503 -max-line-length = 120 -exclude = 00*.py,generated.py diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index f05a15a..0000000 --- a/.pylintrc +++ /dev/null @@ -1,20 +0,0 @@ -[MESSAGES CONTROL] - -disable=logging-fstring-interpolation - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=120 - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -[DESIGN] - -# Maximum number of locals for function / method body. -max-locals=20 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=0 diff --git a/pylint b/flake8 similarity index 61% rename from pylint rename to flake8 index d88ee0c..3ae725b 100755 --- a/pylint +++ b/flake8 @@ -12,9 +12,8 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then exit 1 fi -if [ ! -x ./venv/bin/pylint ]; then - # need current pylint to deal with more recent python features - ./venv/bin/pip install pylint +if [ ! -x ./venv/bin/flake8 ]; then + ./venv/bin/pip install flake8 flake8-import-order fi -./venv/bin/pylint src +./venv/bin/flake8 src diff --git a/mypy b/mypy index 095418b..c50d79b 100755 --- a/mypy +++ b/mypy @@ -13,12 +13,30 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then fi if [ ! -x ./venv/bin/mypy ]; then - ./venv/bin/pip install mypy + ./venv/bin/pip install mypy trio-typing[mypy] types-PyYAML types-aiofiles types-colorama types-cryptography types-protobuf types-toml fi site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])') if [ ! -d "${site_pkgs}/trio_typing" ]; then ./venv/bin/pip install trio-typing[mypy] fi +if [ ! -d "${site_pkgs}/yaml-stubs" ]; then + ./venv/bin/pip install types-PyYAML +fi +if [ ! -d "${site_pkgs}/aiofiles-stubs" ]; then + ./venv/bin/pip install types-aiofiles +fi +if [ ! -d "${site_pkgs}/colorama-stubs" ]; then + ./venv/bin/pip install types-colorama +fi +if [ ! -d "${site_pkgs}/cryptography-stubs" ]; then + ./venv/bin/pip install types-cryptography +fi +if [ ! -d "${site_pkgs}/google-stubs" ]; then + ./venv/bin/pip install types-protobuf +fi +if [ ! -d "${site_pkgs}/toml-stubs" ]; then + ./venv/bin/pip install types-toml +fi ./venv/bin/mypy --install-types src diff --git a/pyproject.toml b/pyproject.toml index 374b58c..bd74a36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,3 +4,11 @@ requires = [ "wheel" ] build-backend = "setuptools.build_meta" + +[tool.mypy] +python_version = "3.9" +# warn_return_any = true +warn_unused_configs = true +exclude = [ + '_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary) +] diff --git a/src/capport/api/app_cls.py b/src/capport/api/app_cls.py index e2eba7c..eb5f2c2 100644 --- a/src/capport/api/app_cls.py +++ b/src/capport/api/app_cls.py @@ -4,12 +4,15 @@ import os import os.path import typing +import jinja2 + +import quart.templating + +import quart_trio + import capport.comm.hub import capport.config import capport.utils.ipneigh -import jinja2 -import quart.templating -import quart_trio class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader): diff --git a/src/capport/api/proxy_fix.py b/src/capport/api/proxy_fix.py index c574127..6198e64 100644 --- a/src/capport/api/proxy_fix.py +++ b/src/capport/api/proxy_fix.py @@ -4,13 +4,14 @@ import ipaddress import typing import quart -from werkzeug.http import parse_list_header + import werkzeug +from werkzeug.http import parse_list_header from .app import app -def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str]=()) -> typing.Optional[str]: +def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]: if not value_list: return None values = parse_list_header(value_list) @@ -21,6 +22,8 @@ def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequenc def local_proxy_fix(request: quart.Request): + if not request.remote_addr: + return try: addr = ipaddress.ip_address(request.remote_addr) except ValueError: diff --git a/src/capport/api/setup.py b/src/capport/api/setup.py index b9707b1..898984f 100644 --- a/src/capport/api/setup.py +++ b/src/capport/api/setup.py @@ -3,13 +3,13 @@ from __future__ import annotations import logging import uuid -import capport.database -import capport.comm.hub -import capport.comm.message -import capport.utils.cli -import capport.utils.ipneigh import trio +import capport.comm.hub +import capport.comm.message +import capport.database +import capport.utils.cli +import capport.utils.ipneigh from capport.utils.sd_notify import open_sdnotify from .app import app @@ -19,7 +19,12 @@ _logger = logging.getLogger(__name__) class ApiHubApp(capport.comm.hub.HubApplication): - async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + async def mac_states_changed( + self, + *, + from_peer_id: uuid.UUID, + pending_updates: capport.database.PendingUpdates, + ) -> None: # TODO: support websocket notification updates to clients? pass diff --git a/src/capport/api/views.py b/src/capport/api/views.py index 2368aab..7ebf81b 100644 --- a/src/capport/api/views.py +++ b/src/capport/api/views.py @@ -4,13 +4,15 @@ import ipaddress import logging import typing +import quart + +import trio + import capport.comm.hub import capport.comm.message import capport.database import capport.utils.cli import capport.utils.ipneigh -import quart -import trio from capport import cptypes from .app import app @@ -20,22 +22,27 @@ _logger = logging.getLogger(__name__) def get_client_ip() -> cptypes.IPAddress: + remote_addr = quart.request.remote_addr + if not remote_addr: + quart.abort(500, 'Missing client address') try: - addr = ipaddress.ip_address(quart.request.remote_addr) + addr = ipaddress.ip_address(remote_addr) except ValueError as e: - _logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') + _logger.warning(f'Invalid client address {remote_addr!r}: {e}') quart.abort(500, 'Invalid client address') return addr -async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]: +async def get_client_mac_if_present( + address: typing.Optional[cptypes.IPAddress] = None, +) -> typing.Optional[cptypes.MacAddress]: assert app.my_nc # for mypy if not address: address = get_client_ip() return await app.my_nc.get_neighbor_mac(address) -async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress: +async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress: mac = await get_client_mac_if_present(address) if mac is None: _logger.warning(f"Couldn't find MAC addresss for {address}") @@ -109,7 +116,7 @@ def check_self_origin(): @app.route('/', methods=['GET']) -async def index(missing_accept: bool=False): +async def index(missing_accept: bool = False): state = await user_lookup() if not state.mac: return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept) diff --git a/src/capport/comm/hub.py b/src/capport/comm/hub.py index 4805ef0..8558771 100644 --- a/src/capport/comm/hub.py +++ b/src/capport/comm/hub.py @@ -10,10 +10,12 @@ import struct import typing import uuid -import capport.database -import capport.comm.message import trio +import capport.comm.message +import capport.database + + if typing.TYPE_CHECKING: from ..config import Config @@ -57,7 +59,9 @@ class Channel: peer_hello = (await self.recv_msg()).to_variant() if not isinstance(peer_hello, capport.comm.message.Hello): raise HubConnectionReadError("Expected Hello as first message") - auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)) + auth_succ = \ + (peer_hello.authentication == + self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)) await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) peer_auth = (await self.recv_msg()).to_variant() if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): @@ -159,7 +163,7 @@ class Connection: _logger.warning(f"{self._channel}: send timed out") except ConnectionError as e: _logger.warning(f"{self._channel}: failed sending: {e!r}") - except Exception as e: + except Exception: _logger.exception(f"{self._channel}: failed sending") finally: cancel_scope.cancel() @@ -224,7 +228,7 @@ class Connection: with trio.fail_after(5): peer = await channel.do_handshake() except trio.TooSlowError: - _logger.warning(f"Handshake timed out") + _logger.warning("Handshake timed out") return conn = Connection(hub, channel, peer) await conn._run() @@ -278,7 +282,12 @@ class HubApplication: if _logger.isEnabledFor(logging.DEBUG): _logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}") - async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + async def mac_states_changed( + self, + *, + from_peer_id: uuid.UUID, + pending_updates: capport.database.PendingUpdates, + ) -> None: if _logger.isEnabledFor(logging.DEBUG): _logger.debug(f"Received new states from {from_peer_id}: {pending_updates}") @@ -303,8 +312,8 @@ class Hub: # -> AECDH-AES256-SHA # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0') - self._controllers: typing.Dict[str, ControllerConn] = {} - self._established: typing.Dict[uuid.UUID, Connection] = {} + self._controllers: dict[str, ControllerConn] = {} + self._established: dict[uuid.UUID, Connection] = {} async def _accept(self, stream): remotename = stream.socket.getpeername() @@ -436,7 +445,7 @@ class Hub: if conn: await conn.send_msg(*msgs) - async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None): + async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None): async with trio.open_nursery() as nursery: for peer_id, conn in self._established.items(): if peer_id == exclude: diff --git a/src/capport/comm/message.py b/src/capport/comm/message.py index 7cd1750..c1eabd2 100644 --- a/src/capport/comm/message.py +++ b/src/capport/comm/message.py @@ -19,7 +19,6 @@ def _make_to_message(oneof_field): return to_message - def _monkey_patch(): g = globals() g['Message'] = message_pb2.Message diff --git a/src/capport/config.py b/src/capport/config.py index eb43a53..3317115 100644 --- a/src/capport/config.py +++ b/src/capport/config.py @@ -28,7 +28,7 @@ class Config: return _cached_config @staticmethod - def load(filename: typing.Optional[str]=None) -> Config: + def load(filename: typing.Optional[str] = None) -> Config: if filename is None: for name in ('capport.yaml', '/etc/capport.yaml'): if os.path.exists(name): diff --git a/src/capport/control/run.py b/src/capport/control/run.py index ffc0871..da042df 100644 --- a/src/capport/control/run.py +++ b/src/capport/control/run.py @@ -1,23 +1,20 @@ from __future__ import annotations -import logging import typing import uuid +import trio + import capport.comm.hub import capport.comm.message import capport.config import capport.database import capport.utils.cli import capport.utils.nft_set -import trio from capport import cptypes from capport.utils.sd_notify import open_sdnotify -_logger = logging.getLogger(__name__) - - class ControlApp(capport.comm.hub.HubApplication): hub: capport.comm.hub.Hub @@ -25,10 +22,18 @@ class ControlApp(capport.comm.hub.HubApplication): super().__init__() self.nft_set = capport.utils.nft_set.NftSet() - async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + async def mac_states_changed( + self, + *, + from_peer_id: uuid.UUID, + pending_updates: capport.database.PendingUpdates, + ) -> None: self.apply_db_entries(pending_updates.changes()) - def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None: + def apply_db_entries( + self, + entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]], + ) -> None: # deploy changes to netfilter set inserts = [] removals = [] diff --git a/src/capport/cptypes.py b/src/capport/cptypes.py index cbab847..953b68b 100644 --- a/src/capport/cptypes.py +++ b/src/capport/cptypes.py @@ -1,6 +1,5 @@ from __future__ import annotations -import capport.utils.zoneinfo import dataclasses import datetime import ipaddress @@ -10,6 +9,8 @@ import typing import quart +import capport.utils.zoneinfo + if typing.TYPE_CHECKING: from .config import Config @@ -93,7 +94,7 @@ class MacPublicState: return now + datetime.timedelta(seconds=self.allowed_remaining) def to_rfc8908(self, config: Config) -> quart.Response: - response: typing.Dict[str, typing.Any] = { + response: dict[str, typing.Any] = { 'user-portal-url': quart.url_for('index', _external=True), } if config.venue_info_url: @@ -104,4 +105,8 @@ class MacPublicState: response['captive'] = False response['seconds-remaining'] = self.allowed_remaining response['can-extend-session'] = True - return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json') + return quart.Response( + json.dumps(response), + headers={'Cache-Control': 'private'}, + content_type='application/captive+json', + ) diff --git a/src/capport/database.py b/src/capport/database.py index 522bebc..3907dbd 100644 --- a/src/capport/database.py +++ b/src/capport/database.py @@ -8,6 +8,7 @@ import struct import typing import google.protobuf.message + import trio import capport.comm.message @@ -74,7 +75,7 @@ class MacEntry: changed = True self.allowed = True # set allow_until to max of both - if new.allow_until: # if not set nothing to change in local data + if new.allow_until: # if not set nothing to change in local data if not self.allow_until or self.allow_until < new.allow_until: changed = True self.allow_until = new.allow_until @@ -89,7 +90,7 @@ class MacEntry: return cptypes.Timestamp(epoch=elc) # returns 0 if not allowed - def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int: + def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int: if not self.allowed or not self.allow_until: return 0 if not now: @@ -97,14 +98,14 @@ class MacEntry: assert self.allow_until return max(self.allow_until.epoch - now.epoch, 0) - def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool: + def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool: if not now: now = cptypes.Timestamp.now() return now.epoch > self.timeout().epoch # might use this to serialize into file - don't need Message variant there -def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]: +def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]: result: typing.List[capport.comm.message.MacStates] = [] current = capport.comm.message.MacStates() for addr, entry in macs.items(): @@ -118,7 +119,9 @@ def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> ty return result -def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]: +def _serialize_mac_states_as_messages( + macs: dict[cptypes.MacAddress, MacEntry], +) -> typing.List[capport.comm.message.Message]: return [s.to_message() for s in _serialize_mac_states(macs)] @@ -136,8 +139,8 @@ class NotReadyYet(Exception): class Database: - def __init__(self, state_filename: typing.Optional[str]=None): - self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {} + def __init__(self, state_filename: typing.Optional[str] = None): + self._macs: dict[cptypes.MacAddress, MacEntry] = {} self._state_filename = state_filename self._changed_since_last_cleanup = False self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[ @@ -303,7 +306,7 @@ class Database: class PendingUpdates: def __init__(self, database: Database): - self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {} + self._changes: dict[cptypes.MacAddress, MacEntry] = {} self._database = database self._closed = True self._serialized_states: typing.List[capport.comm.message.MacStates] = [] @@ -348,7 +351,7 @@ class PendingUpdates: self._database._macs.pop(addr) self._changes[addr] = old_entry - def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8): + def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8): if self._closed: raise Exception("Can't change closed PendingUpdates") now = cptypes.Timestamp.now() diff --git a/src/capport/stats.py b/src/capport/stats.py index 2831890..fb82db7 100644 --- a/src/capport/stats.py +++ b/src/capport/stats.py @@ -3,16 +3,16 @@ from __future__ import annotations import ipaddress import sys import typing -import time import trio import capport.utils.ipneigh import capport.utils.nft_set + from . import cptypes -def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int]=None, help: typing.Optional[str]=None): +def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = None, help: typing.Optional[str] = None): # no labels in our names for now, always print help and type if help: print(f'# HELP {name} {help}') @@ -47,13 +47,48 @@ async def amain(client_ifname: str): else: total_ipv6 += 1 unique_ipv6.add(mac) - print_metric('capport_allowed_macs', 'gauge', len(captive_allowed_entries), help='Number of allowed client mac addresses') - print_metric('capport_allowed_neigh_macs', 'gauge', len(seen_allowed_entries), help='Number of allowed client mac addresses seen in neighbor cache') - print_metric('capport_unique', 'gauge', len(unique_clients), help='Number of clients (mac addresses) in client network seen in neighbor cache') - print_metric('capport_unique_ipv4', 'gauge', len(unique_ipv4), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache') - print_metric('capport_unique_ipv6', 'gauge', len(unique_ipv6), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache') - print_metric('capport_total_ipv4', 'gauge', total_ipv4, help='Number of IPv4 addresses seen in neighbor cache') - print_metric('capport_total_ipv6', 'gauge', total_ipv6, help='Number of IPv6 addresses seen in neighbor cache') + print_metric( + 'capport_allowed_macs', + 'gauge', + len(captive_allowed_entries), + help='Number of allowed client mac addresses', + ) + print_metric( + 'capport_allowed_neigh_macs', + 'gauge', + len(seen_allowed_entries), + help='Number of allowed client mac addresses seen in neighbor cache', + ) + print_metric( + 'capport_unique', + 'gauge', + len(unique_clients), + help='Number of clients (mac addresses) in client network seen in neighbor cache', + ) + print_metric( + 'capport_unique_ipv4', + 'gauge', + len(unique_ipv4), + help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache', + ) + print_metric( + 'capport_unique_ipv6', + 'gauge', + len(unique_ipv6), + help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache', + ) + print_metric( + 'capport_total_ipv4', + 'gauge', + total_ipv4, + help='Number of IPv4 addresses seen in neighbor cache', + ) + print_metric( + 'capport_total_ipv6', + 'gauge', + total_ipv6, + help='Number of IPv6 addresses seen in neighbor cache', + ) def main(): diff --git a/src/capport/utils/ipneigh.py b/src/capport/utils/ipneigh.py index 797003a..6836f80 100644 --- a/src/capport/utils/ipneigh.py +++ b/src/capport/utils/ipneigh.py @@ -6,10 +6,11 @@ import ipaddress import socket import typing -import pyroute2.iproute.linux -import pyroute2.netlink.exceptions -import pyroute2.netlink.rtnl -import pyroute2.netlink.rtnl.ndmsg +import pyroute2.iproute.linux # type: ignore +import pyroute2.netlink.exceptions # type: ignore +import pyroute2.netlink.rtnl # type: ignore +import pyroute2.netlink.rtnl.ndmsg # type: ignore + from capport import cptypes @@ -27,8 +28,8 @@ class NeighborController: self, address: cptypes.IPAddress, *, - index: int=0, # interface index - flags: int=0, + index: int = 0, # interface index + flags: int = 0, ) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]: if not index: route = await self.get_route(address) @@ -46,8 +47,8 @@ class NeighborController: self, address: cptypes.IPAddress, *, - index: int=0, # interface index - flags: int=0, + index: int = 0, # interface index + flags: int = 0, ) -> typing.Optional[cptypes.MacAddress]: neigh = await self.get_neighbor(address, index=index, flags=flags) if neigh is None: @@ -66,7 +67,10 @@ class NeighborController: return None raise - async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: + async def dump_neighbors( + self, + interface: str, + ) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: ifindex = socket.if_nametoindex(interface) unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) diff --git a/src/capport/utils/nft_set.py b/src/capport/utils/nft_set.py index 0503ad4..7731536 100644 --- a/src/capport/utils/nft_set.py +++ b/src/capport/utils/nft_set.py @@ -2,9 +2,10 @@ from __future__ import annotations import typing -import pyroute2.netlink +import pyroute2.netlink # type: ignore +from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore + from capport import cptypes -from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket from .nft_socket import NFTSocket @@ -24,8 +25,11 @@ class NftSet: self._socket.bind() @staticmethod - def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem: - attrs: typing.Dict[str, typing.Any] = { + def _set_elem( + mac: cptypes.MacAddress, + timeout: typing.Optional[typing.Union[int, float]] = None, + ) -> _nftsocket.nft_set_elem_list_msg.set_elem: + attrs: dict[str, typing.Any] = { 'NFTA_SET_ELEM_KEY': dict( NFTA_DATA_VALUE=mac.raw, ), @@ -34,7 +38,10 @@ class NftSet: attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout) return attrs - def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None: + def _bulk_insert( + self, + entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]], + ) -> None: ser_entries = [ self._set_elem(mac) for mac, _timeout in entries @@ -69,7 +76,7 @@ class NftSet: # now create entries with new timeout value tx.put( _nftsocket.NFT_MSG_NEWSETELEM, - pyroute2.netlink.NLM_F_CREATE|pyroute2.netlink.NLM_F_EXCL, + pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_SET_TABLE='captive_mark', @@ -173,7 +180,7 @@ class NftSet: attrs=dict( NFTA_SET_TABLE='captive_mark', NFTA_SET_NAME='allowed', - NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT + NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft NFTA_SET_KEY_LEN=6, # length of key: mac address NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py index b5e7092..9cf2670 100644 --- a/src/capport/utils/nft_socket.py +++ b/src/capport/utils/nft_socket.py @@ -2,13 +2,12 @@ from __future__ import annotations import contextlib import typing -import threading -from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket -import pyroute2.netlink -import pyroute2.netlink.nlsocket -from pyroute2.netlink.nfnetlink import nfgen_msg -from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES +import pyroute2.netlink # type: ignore +import pyroute2.netlink.nlsocket # type: ignore +from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore +from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore +from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" @@ -45,10 +44,12 @@ def _monkey_patch_pyroute2(): overwrite_methods(subcls) overwrite_methods(pyroute2.netlink.nlmsg_base) + + _monkey_patch_pyroute2() -def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase: +def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: msg = msg_class() for key, value in header.items(): msg['header'][key] = value @@ -66,7 +67,9 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: # support passing nested attributes as dicts of subattributes (or lists of those) if prime['nla_array']: value = [ - _build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem + _build(nla_class, attrs=elem) + if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) + else elem for elem in value ] elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict): @@ -76,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): - policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy + policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy def __init__(self) -> None: super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) @@ -98,15 +101,15 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): # abort does nothing if commit went through tx.abort() - def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: with self.begin() as tx: tx.put(msg_type, msg_flags, attrs=attrs, **fields) - def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_flags |= pyroute2.netlink.NLM_F_DUMP return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) - def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_flags |= pyroute2.netlink.NLM_F_REQUEST @@ -207,7 +210,7 @@ class NFTTransaction: self._data += self._final_msg.data self._final_msg = msg - def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set! diff --git a/src/capport/utils/sd_notify.py b/src/capport/utils/sd_notify.py index 66c52d9..65a5aef 100644 --- a/src/capport/utils/sd_notify.py +++ b/src/capport/utils/sd_notify.py @@ -1,11 +1,12 @@ from __future__ import annotations -import typing -import trio -import trio.socket import contextlib import os import socket +import typing + +import trio +import trio.socket def _check_watchdog_pid() -> bool: @@ -63,7 +64,7 @@ class SdNotify: if not self.is_connected(): return dgram = '\n'.join(msg).encode('utf-8') + assert self._ns, "not connected" # checked above sent = await self._ns.send(dgram) if sent != len(dgram): raise OSError("Sent incomplete datagram to NOTIFY_SOCKET") -