2
0

flake8 linting

This commit is contained in:
Stefan Bühler 2023-01-12 13:16:51 +01:00
parent f9c9b98868
commit 9325950f51
21 changed files with 220 additions and 126 deletions

11
.flake8 Normal file
View File

@ -0,0 +1,11 @@
[flake8]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E402 module level import not at top of file [ usually on purpose. might use individual overrides instead? ]
# E701 multiple statements on one line [ still quite readable in short forms ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 line break before binary operator [ gotta pick one way ]
extend-ignore = E266,E402,E701,E713,E714,W503
max-line-length = 120
exclude = *_pb2.py
application-import-names = capport

View File

@ -1,11 +0,0 @@
[pycodestyle]
# E241 multiple spaces after ':' [ want to align stuff ]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E501 line too long [ temporary? can't disable it in certain places.. ]
# E701 multiple statements on one line (colon) [ perfectly readable ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ]
ignore = E241,E266,E501,E701,E713,E714,W503
max-line-length = 120
exclude = 00*.py,generated.py

View File

@ -1,20 +0,0 @@
[MESSAGES CONTROL]
disable=logging-fstring-interpolation
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
[DESIGN]
# Maximum number of locals for function / method body.
max-locals=20
# Minimum number of public methods for a class (see R0903).
min-public-methods=0

View File

@ -12,9 +12,8 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
exit 1 exit 1
fi fi
if [ ! -x ./venv/bin/pylint ]; then if [ ! -x ./venv/bin/flake8 ]; then
# need current pylint to deal with more recent python features ./venv/bin/pip install flake8 flake8-import-order
./venv/bin/pip install pylint
fi fi
./venv/bin/pylint src ./venv/bin/flake8 src

20
mypy
View File

@ -13,12 +13,30 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
fi fi
if [ ! -x ./venv/bin/mypy ]; then if [ ! -x ./venv/bin/mypy ]; then
./venv/bin/pip install mypy ./venv/bin/pip install mypy trio-typing[mypy] types-PyYAML types-aiofiles types-colorama types-cryptography types-protobuf types-toml
fi fi
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])') site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
if [ ! -d "${site_pkgs}/trio_typing" ]; then if [ ! -d "${site_pkgs}/trio_typing" ]; then
./venv/bin/pip install trio-typing[mypy] ./venv/bin/pip install trio-typing[mypy]
fi fi
if [ ! -d "${site_pkgs}/yaml-stubs" ]; then
./venv/bin/pip install types-PyYAML
fi
if [ ! -d "${site_pkgs}/aiofiles-stubs" ]; then
./venv/bin/pip install types-aiofiles
fi
if [ ! -d "${site_pkgs}/colorama-stubs" ]; then
./venv/bin/pip install types-colorama
fi
if [ ! -d "${site_pkgs}/cryptography-stubs" ]; then
./venv/bin/pip install types-cryptography
fi
if [ ! -d "${site_pkgs}/google-stubs" ]; then
./venv/bin/pip install types-protobuf
fi
if [ ! -d "${site_pkgs}/toml-stubs" ]; then
./venv/bin/pip install types-toml
fi
./venv/bin/mypy --install-types src ./venv/bin/mypy --install-types src

View File

@ -4,3 +4,11 @@ requires = [
"wheel" "wheel"
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.mypy]
python_version = "3.9"
# warn_return_any = true
warn_unused_configs = true
exclude = [
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
]

View File

@ -4,12 +4,15 @@ import os
import os.path import os.path
import typing import typing
import jinja2
import quart.templating
import quart_trio
import capport.comm.hub import capport.comm.hub
import capport.config import capport.config
import capport.utils.ipneigh import capport.utils.ipneigh
import jinja2
import quart.templating
import quart_trio
class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader): class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):

View File

@ -4,8 +4,9 @@ import ipaddress
import typing import typing
import quart import quart
from werkzeug.http import parse_list_header
import werkzeug import werkzeug
from werkzeug.http import parse_list_header
from .app import app from .app import app
@ -21,6 +22,8 @@ def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequenc
def local_proxy_fix(request: quart.Request): def local_proxy_fix(request: quart.Request):
if not request.remote_addr:
return
try: try:
addr = ipaddress.ip_address(request.remote_addr) addr = ipaddress.ip_address(request.remote_addr)
except ValueError: except ValueError:

View File

@ -3,13 +3,13 @@ from __future__ import annotations
import logging import logging
import uuid import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.utils.cli
import capport.utils.ipneigh
import trio import trio
import capport.comm.hub
import capport.comm.message
import capport.database
import capport.utils.cli
import capport.utils.ipneigh
from capport.utils.sd_notify import open_sdnotify from capport.utils.sd_notify import open_sdnotify
from .app import app from .app import app
@ -19,7 +19,12 @@ _logger = logging.getLogger(__name__)
class ApiHubApp(capport.comm.hub.HubApplication): class ApiHubApp(capport.comm.hub.HubApplication):
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
# TODO: support websocket notification updates to clients? # TODO: support websocket notification updates to clients?
pass pass

View File

@ -4,13 +4,15 @@ import ipaddress
import logging import logging
import typing import typing
import quart
import trio
import capport.comm.hub import capport.comm.hub
import capport.comm.message import capport.comm.message
import capport.database import capport.database
import capport.utils.cli import capport.utils.cli
import capport.utils.ipneigh import capport.utils.ipneigh
import quart
import trio
from capport import cptypes from capport import cptypes
from .app import app from .app import app
@ -20,15 +22,20 @@ _logger = logging.getLogger(__name__)
def get_client_ip() -> cptypes.IPAddress: def get_client_ip() -> cptypes.IPAddress:
remote_addr = quart.request.remote_addr
if not remote_addr:
quart.abort(500, 'Missing client address')
try: try:
addr = ipaddress.ip_address(quart.request.remote_addr) addr = ipaddress.ip_address(remote_addr)
except ValueError as e: except ValueError as e:
_logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') _logger.warning(f'Invalid client address {remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address') quart.abort(500, 'Invalid client address')
return addr return addr
async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]: async def get_client_mac_if_present(
address: typing.Optional[cptypes.IPAddress] = None,
) -> typing.Optional[cptypes.MacAddress]:
assert app.my_nc # for mypy assert app.my_nc # for mypy
if not address: if not address:
address = get_client_ip() address = get_client_ip()

View File

@ -10,10 +10,12 @@ import struct
import typing import typing
import uuid import uuid
import capport.database
import capport.comm.message
import trio import trio
import capport.comm.message
import capport.database
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from ..config import Config from ..config import Config
@ -57,7 +59,9 @@ class Channel:
peer_hello = (await self.recv_msg()).to_variant() peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello): if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message") raise HubConnectionReadError("Expected Hello as first message")
auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)) auth_succ = \
(peer_hello.authentication ==
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant() peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
@ -159,7 +163,7 @@ class Connection:
_logger.warning(f"{self._channel}: send timed out") _logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e: except ConnectionError as e:
_logger.warning(f"{self._channel}: failed sending: {e!r}") _logger.warning(f"{self._channel}: failed sending: {e!r}")
except Exception as e: except Exception:
_logger.exception(f"{self._channel}: failed sending") _logger.exception(f"{self._channel}: failed sending")
finally: finally:
cancel_scope.cancel() cancel_scope.cancel()
@ -224,7 +228,7 @@ class Connection:
with trio.fail_after(5): with trio.fail_after(5):
peer = await channel.do_handshake() peer = await channel.do_handshake()
except trio.TooSlowError: except trio.TooSlowError:
_logger.warning(f"Handshake timed out") _logger.warning("Handshake timed out")
return return
conn = Connection(hub, channel, peer) conn = Connection(hub, channel, peer)
await conn._run() await conn._run()
@ -278,7 +282,12 @@ class HubApplication:
if _logger.isEnabledFor(logging.DEBUG): if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}") _logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
if _logger.isEnabledFor(logging.DEBUG): if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}") _logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
@ -303,8 +312,8 @@ class Hub:
# -> AECDH-AES256-SHA # -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0') self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._controllers: typing.Dict[str, ControllerConn] = {} self._controllers: dict[str, ControllerConn] = {}
self._established: typing.Dict[uuid.UUID, Connection] = {} self._established: dict[uuid.UUID, Connection] = {}
async def _accept(self, stream): async def _accept(self, stream):
remotename = stream.socket.getpeername() remotename = stream.socket.getpeername()

View File

@ -19,7 +19,6 @@ def _make_to_message(oneof_field):
return to_message return to_message
def _monkey_patch(): def _monkey_patch():
g = globals() g = globals()
g['Message'] = message_pb2.Message g['Message'] = message_pb2.Message

View File

@ -1,23 +1,20 @@
from __future__ import annotations from __future__ import annotations
import logging
import typing import typing
import uuid import uuid
import trio
import capport.comm.hub import capport.comm.hub
import capport.comm.message import capport.comm.message
import capport.config import capport.config
import capport.database import capport.database
import capport.utils.cli import capport.utils.cli
import capport.utils.nft_set import capport.utils.nft_set
import trio
from capport import cptypes from capport import cptypes
from capport.utils.sd_notify import open_sdnotify from capport.utils.sd_notify import open_sdnotify
_logger = logging.getLogger(__name__)
class ControlApp(capport.comm.hub.HubApplication): class ControlApp(capport.comm.hub.HubApplication):
hub: capport.comm.hub.Hub hub: capport.comm.hub.Hub
@ -25,10 +22,18 @@ class ControlApp(capport.comm.hub.HubApplication):
super().__init__() super().__init__()
self.nft_set = capport.utils.nft_set.NftSet() self.nft_set = capport.utils.nft_set.NftSet()
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
self.apply_db_entries(pending_updates.changes()) self.apply_db_entries(pending_updates.changes())
def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None: def apply_db_entries(
self,
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
) -> None:
# deploy changes to netfilter set # deploy changes to netfilter set
inserts = [] inserts = []
removals = [] removals = []

View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import capport.utils.zoneinfo
import dataclasses import dataclasses
import datetime import datetime
import ipaddress import ipaddress
@ -10,6 +9,8 @@ import typing
import quart import quart
import capport.utils.zoneinfo
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .config import Config from .config import Config
@ -93,7 +94,7 @@ class MacPublicState:
return now + datetime.timedelta(seconds=self.allowed_remaining) return now + datetime.timedelta(seconds=self.allowed_remaining)
def to_rfc8908(self, config: Config) -> quart.Response: def to_rfc8908(self, config: Config) -> quart.Response:
response: typing.Dict[str, typing.Any] = { response: dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True), 'user-portal-url': quart.url_for('index', _external=True),
} }
if config.venue_info_url: if config.venue_info_url:
@ -104,4 +105,8 @@ class MacPublicState:
response['captive'] = False response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True response['can-extend-session'] = True
return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json') return quart.Response(
json.dumps(response),
headers={'Cache-Control': 'private'},
content_type='application/captive+json',
)

View File

@ -8,6 +8,7 @@ import struct
import typing import typing
import google.protobuf.message import google.protobuf.message
import trio import trio
import capport.comm.message import capport.comm.message
@ -104,7 +105,7 @@ class MacEntry:
# might use this to serialize into file - don't need Message variant there # might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]: def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = [] result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates() current = capport.comm.message.MacStates()
for addr, entry in macs.items(): for addr, entry in macs.items():
@ -118,7 +119,9 @@ def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> ty
return result return result
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]: def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)] return [s.to_message() for s in _serialize_mac_states(macs)]
@ -137,7 +140,7 @@ class NotReadyYet(Exception):
class Database: class Database:
def __init__(self, state_filename: typing.Optional[str] = None): def __init__(self, state_filename: typing.Optional[str] = None):
self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {} self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename self._state_filename = state_filename
self._changed_since_last_cleanup = False self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[ self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
@ -303,7 +306,7 @@ class Database:
class PendingUpdates: class PendingUpdates:
def __init__(self, database: Database): def __init__(self, database: Database):
self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {} self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database self._database = database
self._closed = True self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = [] self._serialized_states: typing.List[capport.comm.message.MacStates] = []

View File

@ -3,12 +3,12 @@ from __future__ import annotations
import ipaddress import ipaddress
import sys import sys
import typing import typing
import time
import trio import trio
import capport.utils.ipneigh import capport.utils.ipneigh
import capport.utils.nft_set import capport.utils.nft_set
from . import cptypes from . import cptypes
@ -47,13 +47,48 @@ async def amain(client_ifname: str):
else: else:
total_ipv6 += 1 total_ipv6 += 1
unique_ipv6.add(mac) unique_ipv6.add(mac)
print_metric('capport_allowed_macs', 'gauge', len(captive_allowed_entries), help='Number of allowed client mac addresses') print_metric(
print_metric('capport_allowed_neigh_macs', 'gauge', len(seen_allowed_entries), help='Number of allowed client mac addresses seen in neighbor cache') 'capport_allowed_macs',
print_metric('capport_unique', 'gauge', len(unique_clients), help='Number of clients (mac addresses) in client network seen in neighbor cache') 'gauge',
print_metric('capport_unique_ipv4', 'gauge', len(unique_ipv4), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache') len(captive_allowed_entries),
print_metric('capport_unique_ipv6', 'gauge', len(unique_ipv6), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache') help='Number of allowed client mac addresses',
print_metric('capport_total_ipv4', 'gauge', total_ipv4, help='Number of IPv4 addresses seen in neighbor cache') )
print_metric('capport_total_ipv6', 'gauge', total_ipv6, help='Number of IPv6 addresses seen in neighbor cache') print_metric(
'capport_allowed_neigh_macs',
'gauge',
len(seen_allowed_entries),
help='Number of allowed client mac addresses seen in neighbor cache',
)
print_metric(
'capport_unique',
'gauge',
len(unique_clients),
help='Number of clients (mac addresses) in client network seen in neighbor cache',
)
print_metric(
'capport_unique_ipv4',
'gauge',
len(unique_ipv4),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
)
print_metric(
'capport_unique_ipv6',
'gauge',
len(unique_ipv6),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
)
print_metric(
'capport_total_ipv4',
'gauge',
total_ipv4,
help='Number of IPv4 addresses seen in neighbor cache',
)
print_metric(
'capport_total_ipv6',
'gauge',
total_ipv6,
help='Number of IPv6 addresses seen in neighbor cache',
)
def main(): def main():

View File

@ -6,10 +6,11 @@ import ipaddress
import socket import socket
import typing import typing
import pyroute2.iproute.linux import pyroute2.iproute.linux # type: ignore
import pyroute2.netlink.exceptions import pyroute2.netlink.exceptions # type: ignore
import pyroute2.netlink.rtnl import pyroute2.netlink.rtnl # type: ignore
import pyroute2.netlink.rtnl.ndmsg import pyroute2.netlink.rtnl.ndmsg # type: ignore
from capport import cptypes from capport import cptypes
@ -66,7 +67,10 @@ class NeighborController:
return None return None
raise raise
async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: async def dump_neighbors(
self,
interface: str,
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface) ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)

View File

@ -2,9 +2,10 @@ from __future__ import annotations
import typing import typing
import pyroute2.netlink import pyroute2.netlink # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
from capport import cptypes from capport import cptypes
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket
from .nft_socket import NFTSocket from .nft_socket import NFTSocket
@ -24,8 +25,11 @@ class NftSet:
self._socket.bind() self._socket.bind()
@staticmethod @staticmethod
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem: def _set_elem(
attrs: typing.Dict[str, typing.Any] = { mac: cptypes.MacAddress,
timeout: typing.Optional[typing.Union[int, float]] = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict( 'NFTA_SET_ELEM_KEY': dict(
NFTA_DATA_VALUE=mac.raw, NFTA_DATA_VALUE=mac.raw,
), ),
@ -34,7 +38,10 @@ class NftSet:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout) attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
return attrs return attrs
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None: def _bulk_insert(
self,
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
) -> None:
ser_entries = [ ser_entries = [
self._set_elem(mac) self._set_elem(mac)
for mac, _timeout in entries for mac, _timeout in entries

View File

@ -2,13 +2,12 @@ from __future__ import annotations
import contextlib import contextlib
import typing import typing
import threading
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket import pyroute2.netlink # type: ignore
import pyroute2.netlink import pyroute2.netlink.nlsocket # type: ignore
import pyroute2.netlink.nlsocket from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
from pyroute2.netlink.nfnetlink import nfgen_msg from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
@ -45,10 +44,12 @@ def _monkey_patch_pyroute2():
overwrite_methods(subcls) overwrite_methods(subcls)
overwrite_methods(pyroute2.netlink.nlmsg_base) overwrite_methods(pyroute2.netlink.nlmsg_base)
_monkey_patch_pyroute2() _monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase: def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class() msg = msg_class()
for key, value in header.items(): for key, value in header.items():
msg['header'][key] = value msg['header'][key] = value
@ -66,7 +67,9 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header:
# support passing nested attributes as dicts of subattributes (or lists of those) # support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']: if prime['nla_array']:
value = [ value = [
_build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem _build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
else elem
for elem in value for elem in value
] ]
elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict): elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
@ -76,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header:
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
@ -98,15 +101,15 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
# abort does nothing if commit went through # abort does nothing if commit went through
tx.abort() tx.abort()
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
with self.begin() as tx: with self.begin() as tx:
tx.put(msg_type, msg_flags, attrs=attrs, **fields) tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_flags |= pyroute2.netlink.NLM_F_DUMP msg_flags |= pyroute2.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pyroute2.netlink.NLM_F_REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST
@ -207,7 +210,7 @@ class NFTTransaction:
self._data += self._final_msg.data self._data += self._final_msg.data
self._final_msg = msg self._final_msg = msg
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set! msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!

View File

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
import typing
import trio
import trio.socket
import contextlib import contextlib
import os import os
import socket import socket
import typing
import trio
import trio.socket
def _check_watchdog_pid() -> bool: def _check_watchdog_pid() -> bool:
@ -63,7 +64,7 @@ class SdNotify:
if not self.is_connected(): if not self.is_connected():
return return
dgram = '\n'.join(msg).encode('utf-8') dgram = '\n'.join(msg).encode('utf-8')
assert self._ns, "not connected" # checked above
sent = await self._ns.send(dgram) sent = await self._ns.send(dgram)
if sent != len(dgram): if sent != len(dgram):
raise OSError("Sent incomplete datagram to NOTIFY_SOCKET") raise OSError("Sent incomplete datagram to NOTIFY_SOCKET")