2
0

flake8 linting

This commit is contained in:
Stefan Bühler 2023-01-12 13:16:51 +01:00
parent f9c9b98868
commit 9325950f51
21 changed files with 220 additions and 126 deletions

11
.flake8 Normal file
View File

@ -0,0 +1,11 @@
[flake8]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E402 module level import not at top of file [ usually on purpose. might use individual overrides instead? ]
# E701 multiple statements on one line [ still quite readable in short forms ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 line break before binary operator [ gotta pick one way ]
extend-ignore = E266,E402,E701,E713,E714,W503
max-line-length = 120
exclude = *_pb2.py
application-import-names = capport

View File

@ -1,11 +0,0 @@
[pycodestyle]
# E241 multiple spaces after ':' [ want to align stuff ]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E501 line too long [ temporary? can't disable it in certain places.. ]
# E701 multiple statements on one line (colon) [ perfectly readable ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ]
ignore = E241,E266,E501,E701,E713,E714,W503
max-line-length = 120
exclude = 00*.py,generated.py

View File

@ -1,20 +0,0 @@
[MESSAGES CONTROL]
disable=logging-fstring-interpolation
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
[DESIGN]
# Maximum number of locals for function / method body.
max-locals=20
# Minimum number of public methods for a class (see R0903).
min-public-methods=0

View File

@ -12,9 +12,8 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
exit 1
fi
if [ ! -x ./venv/bin/pylint ]; then
# need current pylint to deal with more recent python features
./venv/bin/pip install pylint
if [ ! -x ./venv/bin/flake8 ]; then
./venv/bin/pip install flake8 flake8-import-order
fi
./venv/bin/pylint src
./venv/bin/flake8 src

20
mypy
View File

@ -13,12 +13,30 @@ if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
fi
if [ ! -x ./venv/bin/mypy ]; then
./venv/bin/pip install mypy
./venv/bin/pip install mypy trio-typing[mypy] types-PyYAML types-aiofiles types-colorama types-cryptography types-protobuf types-toml
fi
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
if [ ! -d "${site_pkgs}/trio_typing" ]; then
./venv/bin/pip install trio-typing[mypy]
fi
if [ ! -d "${site_pkgs}/yaml-stubs" ]; then
./venv/bin/pip install types-PyYAML
fi
if [ ! -d "${site_pkgs}/aiofiles-stubs" ]; then
./venv/bin/pip install types-aiofiles
fi
if [ ! -d "${site_pkgs}/colorama-stubs" ]; then
./venv/bin/pip install types-colorama
fi
if [ ! -d "${site_pkgs}/cryptography-stubs" ]; then
./venv/bin/pip install types-cryptography
fi
if [ ! -d "${site_pkgs}/google-stubs" ]; then
./venv/bin/pip install types-protobuf
fi
if [ ! -d "${site_pkgs}/toml-stubs" ]; then
./venv/bin/pip install types-toml
fi
./venv/bin/mypy --install-types src

View File

@ -4,3 +4,11 @@ requires = [
"wheel"
]
build-backend = "setuptools.build_meta"
[tool.mypy]
python_version = "3.9"
# warn_return_any = true
warn_unused_configs = true
exclude = [
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
]

View File

@ -4,12 +4,15 @@ import os
import os.path
import typing
import jinja2
import quart.templating
import quart_trio
import capport.comm.hub
import capport.config
import capport.utils.ipneigh
import jinja2
import quart.templating
import quart_trio
class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):

View File

@ -4,13 +4,14 @@ import ipaddress
import typing
import quart
from werkzeug.http import parse_list_header
import werkzeug
from werkzeug.http import parse_list_header
from .app import app
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str]=()) -> typing.Optional[str]:
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]:
if not value_list:
return None
values = parse_list_header(value_list)
@ -21,6 +22,8 @@ def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequenc
def local_proxy_fix(request: quart.Request):
if not request.remote_addr:
return
try:
addr = ipaddress.ip_address(request.remote_addr)
except ValueError:

View File

@ -3,13 +3,13 @@ from __future__ import annotations
import logging
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.utils.cli
import capport.utils.ipneigh
import trio
import capport.comm.hub
import capport.comm.message
import capport.database
import capport.utils.cli
import capport.utils.ipneigh
from capport.utils.sd_notify import open_sdnotify
from .app import app
@ -19,7 +19,12 @@ _logger = logging.getLogger(__name__)
class ApiHubApp(capport.comm.hub.HubApplication):
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
# TODO: support websocket notification updates to clients?
pass

View File

@ -4,13 +4,15 @@ import ipaddress
import logging
import typing
import quart
import trio
import capport.comm.hub
import capport.comm.message
import capport.database
import capport.utils.cli
import capport.utils.ipneigh
import quart
import trio
from capport import cptypes
from .app import app
@ -20,22 +22,27 @@ _logger = logging.getLogger(__name__)
def get_client_ip() -> cptypes.IPAddress:
remote_addr = quart.request.remote_addr
if not remote_addr:
quart.abort(500, 'Missing client address')
try:
addr = ipaddress.ip_address(quart.request.remote_addr)
addr = ipaddress.ip_address(remote_addr)
except ValueError as e:
_logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}')
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address')
return addr
async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]:
async def get_client_mac_if_present(
address: typing.Optional[cptypes.IPAddress] = None,
) -> typing.Optional[cptypes.MacAddress]:
assert app.my_nc # for mypy
if not address:
address = get_client_ip()
return await app.my_nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress:
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")
@ -109,7 +116,7 @@ def check_self_origin():
@app.route('/', methods=['GET'])
async def index(missing_accept: bool=False):
async def index(missing_accept: bool = False):
state = await user_lookup()
if not state.mac:
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)

View File

@ -10,10 +10,12 @@ import struct
import typing
import uuid
import capport.database
import capport.comm.message
import trio
import capport.comm.message
import capport.database
if typing.TYPE_CHECKING:
from ..config import Config
@ -57,7 +59,9 @@ class Channel:
peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message")
auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
auth_succ = \
(peer_hello.authentication ==
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
@ -159,7 +163,7 @@ class Connection:
_logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed sending: {e!r}")
except Exception as e:
except Exception:
_logger.exception(f"{self._channel}: failed sending")
finally:
cancel_scope.cancel()
@ -224,7 +228,7 @@ class Connection:
with trio.fail_after(5):
peer = await channel.do_handshake()
except trio.TooSlowError:
_logger.warning(f"Handshake timed out")
_logger.warning("Handshake timed out")
return
conn = Connection(hub, channel, peer)
await conn._run()
@ -278,7 +282,12 @@ class HubApplication:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
@ -303,8 +312,8 @@ class Hub:
# -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._controllers: typing.Dict[str, ControllerConn] = {}
self._established: typing.Dict[uuid.UUID, Connection] = {}
self._controllers: dict[str, ControllerConn] = {}
self._established: dict[uuid.UUID, Connection] = {}
async def _accept(self, stream):
remotename = stream.socket.getpeername()
@ -436,7 +445,7 @@ class Hub:
if conn:
await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None):
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None):
async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items():
if peer_id == exclude:

View File

@ -19,7 +19,6 @@ def _make_to_message(oneof_field):
return to_message
def _monkey_patch():
g = globals()
g['Message'] = message_pb2.Message

View File

@ -28,7 +28,7 @@ class Config:
return _cached_config
@staticmethod
def load(filename: typing.Optional[str]=None) -> Config:
def load(filename: typing.Optional[str] = None) -> Config:
if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name):

View File

@ -1,23 +1,20 @@
from __future__ import annotations
import logging
import typing
import uuid
import trio
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.database
import capport.utils.cli
import capport.utils.nft_set
import trio
from capport import cptypes
from capport.utils.sd_notify import open_sdnotify
_logger = logging.getLogger(__name__)
class ControlApp(capport.comm.hub.HubApplication):
hub: capport.comm.hub.Hub
@ -25,10 +22,18 @@ class ControlApp(capport.comm.hub.HubApplication):
super().__init__()
self.nft_set = capport.utils.nft_set.NftSet()
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
self.apply_db_entries(pending_updates.changes())
def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None:
def apply_db_entries(
self,
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
) -> None:
# deploy changes to netfilter set
inserts = []
removals = []

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import capport.utils.zoneinfo
import dataclasses
import datetime
import ipaddress
@ -10,6 +9,8 @@ import typing
import quart
import capport.utils.zoneinfo
if typing.TYPE_CHECKING:
from .config import Config
@ -93,7 +94,7 @@ class MacPublicState:
return now + datetime.timedelta(seconds=self.allowed_remaining)
def to_rfc8908(self, config: Config) -> quart.Response:
response: typing.Dict[str, typing.Any] = {
response: dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True),
}
if config.venue_info_url:
@ -104,4 +105,8 @@ class MacPublicState:
response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True
return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json')
return quart.Response(
json.dumps(response),
headers={'Cache-Control': 'private'},
content_type='application/captive+json',
)

View File

@ -8,6 +8,7 @@ import struct
import typing
import google.protobuf.message
import trio
import capport.comm.message
@ -89,7 +90,7 @@ class MacEntry:
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int:
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
@ -97,14 +98,14 @@ class MacEntry:
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool:
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
@ -118,7 +119,9 @@ def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> ty
return result
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
@ -136,8 +139,8 @@ class NotReadyYet(Exception):
class Database:
def __init__(self, state_filename: typing.Optional[str]=None):
self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {}
def __init__(self, state_filename: typing.Optional[str] = None):
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename
self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
@ -303,7 +306,7 @@ class Database:
class PendingUpdates:
def __init__(self, database: Database):
self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {}
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database
self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
@ -348,7 +351,7 @@ class PendingUpdates:
self._database._macs.pop(addr)
self._changes[addr] = old_entry
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8):
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()

View File

@ -3,16 +3,16 @@ from __future__ import annotations
import ipaddress
import sys
import typing
import time
import trio
import capport.utils.ipneigh
import capport.utils.nft_set
from . import cptypes
def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int]=None, help: typing.Optional[str]=None):
def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = None, help: typing.Optional[str] = None):
# no labels in our names for now, always print help and type
if help:
print(f'# HELP {name} {help}')
@ -47,13 +47,48 @@ async def amain(client_ifname: str):
else:
total_ipv6 += 1
unique_ipv6.add(mac)
print_metric('capport_allowed_macs', 'gauge', len(captive_allowed_entries), help='Number of allowed client mac addresses')
print_metric('capport_allowed_neigh_macs', 'gauge', len(seen_allowed_entries), help='Number of allowed client mac addresses seen in neighbor cache')
print_metric('capport_unique', 'gauge', len(unique_clients), help='Number of clients (mac addresses) in client network seen in neighbor cache')
print_metric('capport_unique_ipv4', 'gauge', len(unique_ipv4), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache')
print_metric('capport_unique_ipv6', 'gauge', len(unique_ipv6), help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache')
print_metric('capport_total_ipv4', 'gauge', total_ipv4, help='Number of IPv4 addresses seen in neighbor cache')
print_metric('capport_total_ipv6', 'gauge', total_ipv6, help='Number of IPv6 addresses seen in neighbor cache')
print_metric(
'capport_allowed_macs',
'gauge',
len(captive_allowed_entries),
help='Number of allowed client mac addresses',
)
print_metric(
'capport_allowed_neigh_macs',
'gauge',
len(seen_allowed_entries),
help='Number of allowed client mac addresses seen in neighbor cache',
)
print_metric(
'capport_unique',
'gauge',
len(unique_clients),
help='Number of clients (mac addresses) in client network seen in neighbor cache',
)
print_metric(
'capport_unique_ipv4',
'gauge',
len(unique_ipv4),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
)
print_metric(
'capport_unique_ipv6',
'gauge',
len(unique_ipv6),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
)
print_metric(
'capport_total_ipv4',
'gauge',
total_ipv4,
help='Number of IPv4 addresses seen in neighbor cache',
)
print_metric(
'capport_total_ipv6',
'gauge',
total_ipv6,
help='Number of IPv6 addresses seen in neighbor cache',
)
def main():

View File

@ -6,10 +6,11 @@ import ipaddress
import socket
import typing
import pyroute2.iproute.linux
import pyroute2.netlink.exceptions
import pyroute2.netlink.rtnl
import pyroute2.netlink.rtnl.ndmsg
import pyroute2.iproute.linux # type: ignore
import pyroute2.netlink.exceptions # type: ignore
import pyroute2.netlink.rtnl # type: ignore
import pyroute2.netlink.rtnl.ndmsg # type: ignore
from capport import cptypes
@ -27,8 +28,8 @@ class NeighborController:
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
if not index:
route = await self.get_route(address)
@ -46,8 +47,8 @@ class NeighborController:
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[cptypes.MacAddress]:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
@ -66,7 +67,10 @@ class NeighborController:
return None
raise
async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
async def dump_neighbors(
self,
interface: str,
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)

View File

@ -2,9 +2,10 @@ from __future__ import annotations
import typing
import pyroute2.netlink
import pyroute2.netlink # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
from capport import cptypes
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket
from .nft_socket import NFTSocket
@ -24,8 +25,11 @@ class NftSet:
self._socket.bind()
@staticmethod
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: typing.Dict[str, typing.Any] = {
def _set_elem(
mac: cptypes.MacAddress,
timeout: typing.Optional[typing.Union[int, float]] = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
NFTA_DATA_VALUE=mac.raw,
),
@ -34,7 +38,10 @@ class NftSet:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
return attrs
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
def _bulk_insert(
self,
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
) -> None:
ser_entries = [
self._set_elem(mac)
for mac, _timeout in entries
@ -69,7 +76,7 @@ class NftSet:
# now create entries with new timeout value
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pyroute2.netlink.NLM_F_CREATE|pyroute2.netlink.NLM_F_EXCL,
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',

View File

@ -2,13 +2,12 @@ from __future__ import annotations
import contextlib
import typing
import threading
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket
import pyroute2.netlink
import pyroute2.netlink.nlsocket
from pyroute2.netlink.nfnetlink import nfgen_msg
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
import pyroute2.netlink # type: ignore
import pyroute2.netlink.nlsocket # type: ignore
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
@ -45,10 +44,12 @@ def _monkey_patch_pyroute2():
overwrite_methods(subcls)
overwrite_methods(pyroute2.netlink.nlmsg_base)
_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase:
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
@ -66,7 +67,9 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header:
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
value = [
_build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem
_build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
else elem
for elem in value
]
elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
@ -76,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header:
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
@ -98,15 +101,15 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
# abort does nothing if commit went through
tx.abort()
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
with self.begin() as tx:
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_flags |= pyroute2.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
@ -207,7 +210,7 @@ class NFTTransaction:
self._data += self._final_msg.data
self._final_msg = msg
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!

View File

@ -1,11 +1,12 @@
from __future__ import annotations
import typing
import trio
import trio.socket
import contextlib
import os
import socket
import typing
import trio
import trio.socket
def _check_watchdog_pid() -> bool:
@ -63,7 +64,7 @@ class SdNotify:
if not self.is_connected():
return
dgram = '\n'.join(msg).encode('utf-8')
assert self._ns, "not connected" # checked above
sent = await self._ns.send(dgram)
if sent != len(dgram):
raise OSError("Sent incomplete datagram to NOTIFY_SOCKET")