2
0

replace some typing.* constructs with native ones

This commit is contained in:
Stefan Bühler 2023-11-15 10:35:38 +01:00
parent 29d1a3226a
commit ced589f28a
16 changed files with 77 additions and 85 deletions

View File

@ -42,10 +42,10 @@ class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):
class MyQuartApp(quart_trio.QuartTrio): class MyQuartApp(quart_trio.QuartTrio):
my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None my_nc: capport.utils.ipneigh.NeighborController | None = None
my_hub: typing.Optional[capport.comm.hub.Hub] = None my_hub: capport.comm.hub.Hub | None = None
my_config: capport.config.Config my_config: capport.config.Config
custom_loader: typing.Optional[jinja2.FileSystemLoader] = None custom_loader: jinja2.FileSystemLoader | None = None
def __init__(self, import_name: str, **kwargs) -> None: def __init__(self, import_name: str, **kwargs) -> None:
self.my_config = capport.config.Config.load_default_once() self.my_config = capport.config.Config.load_default_once()

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import os.path import os.path
import re import re
import typing
import quart import quart
@ -11,7 +10,7 @@ from .app import app
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$') _VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
def parse_accept_language(value: str) -> typing.List[str]: def parse_accept_language(value: str) -> list[str]:
value = value.strip() value = value.strip()
if not value or value == '*': if not value or value == '*':
return [] return []
@ -72,7 +71,7 @@ def detect_language():
async def render_i18n_template(template, /, **kwargs) -> str: async def render_i18n_template(template, /, **kwargs) -> str:
langs: typing.List[str] = quart.g.langs langs: list[str] = quart.g.langs
if not langs: if not langs:
return await quart.render_template(template, **kwargs) return await quart.render_template(template, **kwargs)
names = [ names = [

View File

@ -11,7 +11,7 @@ from werkzeug.http import parse_list_header
from .app import app from .app import app
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]: def _get_first_in_list(value_list: str | None, allowed: typing.Sequence[str] = ()) -> str | None:
if not value_list: if not value_list:
return None return None
values = parse_list_header(value_list) values = parse_list_header(value_list)
@ -37,12 +37,12 @@ def local_proxy_fix(request: quart.Request):
return return
request.remote_addr = client request.remote_addr = client
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https')) scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
port: typing.Optional[int] = None port: int | None = None
if scheme: if scheme:
port = 443 if scheme == 'https' else 80 port = 443 if scheme == 'https' else 80
request.scheme = scheme request.scheme = scheme
host = _get_first_in_list(request.headers.get('X-Forwarded-Host')) host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
port_s: typing.Optional[str] port_s: str | None
if host: if host:
request.host = host request.host = host
if ':' in host and not host.endswith(']'): if ':' in host and not host.endswith(']'):
@ -71,7 +71,7 @@ class LocalProxyFixRequestHandler:
def __init__(self, orig_handle_request): def __init__(self, orig_handle_request):
self._orig_handle_request = orig_handle_request self._orig_handle_request = orig_handle_request
async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]: async def __call__(self, request: quart.Request) -> quart.Response | werkzeug.Response:
# need to patch request before url_adapter is built # need to patch request before url_adapter is built
local_proxy_fix(request) local_proxy_fix(request)
return await self._orig_handle_request(request) return await self._orig_handle_request(request)

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import ipaddress import ipaddress
import logging import logging
import typing
import quart import quart
@ -34,15 +33,15 @@ def get_client_ip() -> cptypes.IPAddress:
async def get_client_mac_if_present( async def get_client_mac_if_present(
address: typing.Optional[cptypes.IPAddress] = None, address: cptypes.IPAddress | None = None,
) -> typing.Optional[cptypes.MacAddress]: ) -> cptypes.MacAddress | None:
assert app.my_nc # for mypy assert app.my_nc # for mypy
if not address: if not address:
address = get_client_ip() address = get_client_ip()
return await app.my_nc.get_neighbor_mac(address) return await app.my_nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress: async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address) mac = await get_client_mac_if_present(address)
if mac is None: if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}") _logger.warning(f"Couldn't find MAC addresss for {address}")

View File

@ -147,7 +147,7 @@ class Connection:
async def _sender(self, cancel_scope: trio.CancelScope) -> None: async def _sender(self, cancel_scope: trio.CancelScope) -> None:
try: try:
msg: typing.Optional[capport.comm.message.Message] msg: capport.comm.message.Message | None
while True: while True:
msg = None msg = None
# make sure we send something every PING_INTERVAL # make sure we send something every PING_INTERVAL
@ -445,7 +445,7 @@ class Hub:
if conn: if conn:
await conn.send_msg(*msgs) await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None): async def broadcast(self, *msgs: capport.comm.message.Message, exclude: uuid.UUID | None = None):
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items(): for peer_id, conn in self._established.items():
if peer_id == exclude: if peer_id == exclude:

View File

@ -14,13 +14,13 @@ class Message(google.protobuf.message.Message):
def __init__( def __init__(
self, self,
*, *,
hello: typing.Optional[Hello]=None, hello: Hello | None=None,
authentication_result: typing.Optional[AuthenticationResult]=None, authentication_result: AuthenticationResult | None=None,
ping: typing.Optional[Ping]=None, ping: Ping | None=None,
mac_states: typing.Optional[MacStates]=None, mac_states: MacStates | None=None,
) -> None: ... ) -> None: ...
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ... def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
class Hello(google.protobuf.message.Message): class Hello(google.protobuf.message.Message):
@ -66,12 +66,12 @@ class Ping(google.protobuf.message.Message):
class MacStates(google.protobuf.message.Message): class MacStates(google.protobuf.message.Message):
states: typing.List[MacState] states: list[MacState]
def __init__( def __init__(
self, self,
*, *,
states: typing.List[MacState]=[], states: list[MacState]=[],
) -> None: ... ) -> None: ...
def to_message(self) -> Message: ... def to_message(self) -> Message: ...

View File

@ -3,23 +3,22 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import os.path import os.path
import typing
import yaml import yaml
_cached_config: typing.Optional[Config] = None _cached_config: Config | None = None
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class Config: class Config:
_source_filename: str _source_filename: str
controllers: typing.List[str] controllers: list[str]
server_names: typing.List[str] server_names: list[str]
comm_secret: str comm_secret: str
cookie_secret: str cookie_secret: str
venue_info_url: typing.Optional[str] venue_info_url: str | None
session_timeout: int # in seconds session_timeout: int # in seconds
api_port: int api_port: int
controller_port: int controller_port: int
@ -27,7 +26,7 @@ class Config:
debug: bool debug: bool
@staticmethod @staticmethod
def load_default_once(filename: typing.Optional[str] = None) -> Config: def load_default_once(filename: str | None = None) -> Config:
global _cached_config global _cached_config
if not _cached_config: if not _cached_config:
_cached_config = Config.load(filename) _cached_config = Config.load(filename)
@ -38,7 +37,7 @@ class Config:
return _cached_config return _cached_config
@staticmethod @staticmethod
def load(filename: typing.Optional[str] = None) -> Config: def load(filename: str | None = None) -> Config:
if filename is None: if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'): for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name): if os.path.exists(name):

View File

@ -34,7 +34,7 @@ class ControlApp(capport.comm.hub.HubApplication):
def apply_db_entries( def apply_db_entries(
self, self,
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]], entries: typing.Iterable[tuple[cptypes.MacAddress, capport.database.MacEntry]],
) -> None: ) -> None:
# deploy changes to netfilter set # deploy changes to netfilter set
inserts = [] inserts = []

View File

@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
from .config import Config from .config import Config
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address] IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -53,7 +53,7 @@ class Timestamp:
return Timestamp(epoch=now) return Timestamp(epoch=now)
@staticmethod @staticmethod
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]: def from_protobuf(epoch: int) -> Timestamp | None:
if epoch: if epoch:
return Timestamp(epoch=epoch) return Timestamp(epoch=epoch)
return None return None
@ -62,7 +62,7 @@ class Timestamp:
@dataclasses.dataclass @dataclasses.dataclass
class MacPublicState: class MacPublicState:
address: IPAddress address: IPAddress
mac: typing.Optional[MacAddress] mac: MacAddress | None
allowed_remaining: int allowed_remaining: int
@staticmethod @staticmethod
@ -88,7 +88,7 @@ class MacPublicState:
return f'{hh}:{mm:02}:{ss:02}' return f'{hh}:{mm:02}:{ss:02}'
@property @property
def allowed_until(self) -> typing.Optional[datetime.datetime]: def allowed_until(self) -> datetime.datetime | None:
zone = capport.utils.zoneinfo.get_local_timezone() zone = capport.utils.zoneinfo.get_local_timezone()
now = datetime.datetime.now(zone).replace(microsecond=0) now = datetime.datetime.now(zone).replace(microsecond=0)
return now + datetime.timedelta(seconds=self.allowed_remaining) return now + datetime.timedelta(seconds=self.allowed_remaining)

View File

@ -28,11 +28,11 @@ class MacEntry:
last_change: cptypes.Timestamp last_change: cptypes.Timestamp
# only if allowed is true and allow_until is set the device can communicate with the internet # only if allowed is true and allow_until is set the device can communicate with the internet
# allow_until must not go backwards (and not get unset) # allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp] allow_until: cptypes.Timestamp | None
allowed: bool allowed: bool
@staticmethod @staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]: def parse_state(msg: capport.comm.message.MacState) -> tuple[cptypes.MacAddress, MacEntry]:
if len(msg.mac_address) < 6: if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short") raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address) addr = cptypes.MacAddress(raw=msg.mac_address)
@ -90,7 +90,7 @@ class MacEntry:
return cptypes.Timestamp(epoch=elc) return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed # returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int: def allowed_remaining(self, now: cptypes.Timestamp | None = None) -> int:
if not self.allowed or not self.allow_until: if not self.allowed or not self.allow_until:
return 0 return 0
if not now: if not now:
@ -98,15 +98,15 @@ class MacEntry:
assert self.allow_until assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0) return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool: def outdated(self, now: cptypes.Timestamp | None = None) -> bool:
if not now: if not now:
now = cptypes.Timestamp.now() now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there # might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]: def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> list[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = [] result: list[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates() current = capport.comm.message.MacStates()
for addr, entry in macs.items(): for addr, entry in macs.items():
state = entry.to_state(addr) state = entry.to_state(addr)
@ -121,7 +121,7 @@ def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.Li
def _serialize_mac_states_as_messages( def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry], macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]: ) -> list[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)] return [s.to_message() for s in _serialize_mac_states(macs)]
@ -143,10 +143,9 @@ class Database:
self._macs: dict[cptypes.MacAddress, MacEntry] = {} self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename self._state_filename = state_filename
self._changed_since_last_cleanup = False self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[ self._send_changes: trio.MemorySendChannel[
capport.comm.message.MacStates, capport.comm.message.MacStates | list[capport.comm.message.MacStates],
typing.List[capport.comm.message.MacStates], ] | None = None
]]] = None
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]: async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
@ -163,7 +162,7 @@ class Database:
def _drop_outdated(self) -> None: def _drop_outdated(self) -> None:
done = False done = False
while not done: while not done:
depr: typing.Set[cptypes.MacAddress] = set() depr: set[cptypes.MacAddress] = set()
now = cptypes.Timestamp.now() now = cptypes.Timestamp.now()
done = True done = True
for mac, entry in self._macs.items(): for mac, entry in self._macs.items():
@ -197,11 +196,11 @@ class Database:
await self._send_changes.send(states) await self._send_changes.send(states)
# for initial handling of all data # for initial handling of all data
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]: def entries(self) -> list[tuple[cptypes.MacAddress, MacEntry]]:
return list(self._macs.items()) return list(self._macs.items())
# for initial sync with new peer # for initial sync with new peer
def serialize(self) -> typing.List[capport.comm.message.Message]: def serialize(self) -> list[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs) return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict: def as_json(self) -> dict:
@ -211,14 +210,12 @@ class Database:
} }
async def _run_statefile(self) -> None: async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[typing.Union[ rx: trio.MemoryReceiveChannel[
capport.comm.message.MacStates, capport.comm.message.MacStates | list[capport.comm.message.MacStates],
typing.List[capport.comm.message.MacStates], ]
]] tx: trio.MemorySendChannel[
tx: trio.MemorySendChannel[typing.Union[ capport.comm.message.MacStates | list[capport.comm.message.MacStates],
capport.comm.message.MacStates, ]
typing.List[capport.comm.message.MacStates],
]]
tx, rx = trio.open_memory_channel(64) tx, rx = trio.open_memory_channel(64)
self._send_changes = tx self._send_changes = tx
@ -226,7 +223,7 @@ class Database:
filename: str = self._state_filename filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}' tmp_filename = f'{filename}.new-{os.getpid()}'
async def resync(all_states: typing.List[capport.comm.message.MacStates]): async def resync(all_states: list[capport.comm.message.MacStates]):
try: try:
async with await trio.open_file(tmp_filename, 'xb') as tf: async with await trio.open_file(tmp_filename, 'xb') as tf:
for states in all_states: for states in all_states:
@ -309,22 +306,22 @@ class PendingUpdates:
self._changes: dict[cptypes.MacAddress, MacEntry] = {} self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database self._database = database
self._closed = True self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = [] self._serialized_states: list[capport.comm.message.MacStates] = []
self._serialized: typing.List[capport.comm.message.Message] = [] self._serialized: list[capport.comm.message.Message] = []
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self._changes) return bool(self._changes)
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]: def changes(self) -> typing.Iterable[tuple[cptypes.MacAddress, MacEntry]]:
return self._changes.items() return self._changes.items()
@property @property
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]: def serialized_states(self) -> list[capport.comm.message.MacStates]:
assert self._closed assert self._closed
return self._serialized_states return self._serialized_states
@property @property
def serialized(self) -> typing.List[capport.comm.message.Message]: def serialized(self) -> list[capport.comm.message.Message]:
assert self._closed assert self._closed
return self._serialized return self._serialized

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import ipaddress import ipaddress
import sys import sys
import typing
import trio import trio
@ -12,7 +11,7 @@ import capport.utils.nft_set
from . import cptypes from . import cptypes
def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = None, help: typing.Optional[str] = None): def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
# no labels in our names for now, always print help and type # no labels in our names for now, always print help and type
if help: if help:
print(f'# HELP {name} {help}') print(f'# HELP {name} {help}')
@ -25,11 +24,11 @@ def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = No
async def amain(client_ifname: str): async def amain(client_ifname: str):
ns = capport.utils.nft_set.NftSet() ns = capport.utils.nft_set.NftSet()
captive_allowed_entries: typing.Set[cptypes.MacAddress] = { captive_allowed_entries: set[cptypes.MacAddress] = {
entry['mac'] entry['mac']
for entry in ns.list() for entry in ns.list()
} }
seen_allowed_entries: typing.Set[cptypes.MacAddress] = set() seen_allowed_entries: set[cptypes.MacAddress] = set()
total_ipv4 = 0 total_ipv4 = 0
total_ipv6 = 0 total_ipv6 = 0
unique_clients = set() unique_clients = set()

View File

@ -30,7 +30,7 @@ class NeighborController:
*, *,
index: int = 0, # interface index index: int = 0, # interface index
flags: int = 0, flags: int = 0,
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]: ) -> pyroute2.iproute.linux.ndmsg.ndmsg | None:
if not index: if not index:
route = await self.get_route(address) route = await self.get_route(address)
if route is None: if route is None:
@ -49,7 +49,7 @@ class NeighborController:
*, *,
index: int = 0, # interface index index: int = 0, # interface index
flags: int = 0, flags: int = 0,
) -> typing.Optional[cptypes.MacAddress]: ) -> cptypes.MacAddress | None:
neigh = await self.get_neighbor(address, index=index, flags=flags) neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None: if neigh is None:
return None return None
@ -61,7 +61,7 @@ class NeighborController:
async def get_route( async def get_route(
self, self,
address: cptypes.IPAddress, address: cptypes.IPAddress,
) -> typing.Optional[pyroute2.iproute.linux.rtmsg]: ) -> pyroute2.iproute.linux.rtmsg | None:
try: try:
return self.ip.route('get', dst=str(address))[0] return self.ip.route('get', dst=str(address))[0]
except pyroute2.netlink.exceptions.NetlinkError as e: except pyroute2.netlink.exceptions.NetlinkError as e:
@ -72,7 +72,7 @@ class NeighborController:
async def dump_neighbors( async def dump_neighbors(
self, self,
interface: str, interface: str,
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: ) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface) ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)

View File

@ -12,7 +12,7 @@ from .nft_socket import NFTSocket
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX" NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]: def _from_msec(msecs: int | None) -> float | None:
# to seconds # to seconds
if msecs is None: if msecs is None:
return None return None
@ -27,7 +27,7 @@ class NftSet:
@staticmethod @staticmethod
def _set_elem( def _set_elem(
mac: cptypes.MacAddress, mac: cptypes.MacAddress,
timeout: typing.Optional[typing.Union[int, float]] = None, timeout: int | float | None = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem: ) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = { attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict( 'NFTA_SET_ELEM_KEY': dict(
@ -40,7 +40,7 @@ class NftSet:
def _bulk_insert( def _bulk_insert(
self, self,
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]], entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
) -> None: ) -> None:
ser_entries = [ ser_entries = [
self._set_elem(mac) self._set_elem(mac)
@ -85,13 +85,13 @@ class NftSet:
), ),
) )
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None: def bulk_insert(self, entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]]) -> None:
# limit chunk size # limit chunk size
while len(entries) > 0: while len(entries) > 0:
self._bulk_insert(entries[:1024]) self._bulk_insert(entries[:1024])
entries = entries[1024:] entries = entries[1024:]
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None: def insert(self, mac: cptypes.MacAddress, timeout: int | float) -> None:
self.bulk_insert([(mac, timeout)]) self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None: def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:

View File

@ -37,7 +37,7 @@ def _monkey_patch_pyroute2():
self['header'].update(header) self['header'].update(header)
return res return res
def overwrite_methods(cls: typing.Type) -> None: def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
if cls.setvalue is _orig_setvalue: if cls.setvalue is _orig_setvalue:
cls.setvalue = _nlmsg_base__setvalue cls.setvalue = _nlmsg_base__setvalue
for subcls in cls.__subclasses__(): for subcls in cls.__subclasses__():
@ -49,7 +49,7 @@ def _monkey_patch_pyroute2():
_monkey_patch_pyroute2() _monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class() msg = msg_class()
for key, value in header.items(): for key, value in header.items():
msg['header'][key] = value msg['header'][key] = value
@ -79,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy policy: dict[int, type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
@ -110,7 +110,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_class: type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pyroute2.netlink.NLM_F_REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields) msg = _build(msg_class, attrs=attrs, **fields)
@ -178,7 +178,7 @@ class NFTTransaction:
self._msgs.append(msg) self._msgs.append(msg)
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] msg_class: type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set! msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict( header = dict(

View File

@ -19,7 +19,7 @@ def _check_watchdog_pid() -> bool:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]: async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
target = os.environ.pop('NOTIFY_SOCKET', None) target = os.environ.pop('NOTIFY_SOCKET', None)
ns: typing.Optional[trio.socket.SocketType] = None ns: trio.socket.SocketType | None = None
watchdog_usec: int = 0 watchdog_usec: int = 0
if target: if target:
if target.startswith('@'): if target.startswith('@'):
@ -44,7 +44,7 @@ async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
class SdNotify: class SdNotify:
def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None: def __init__(self, *, _ns: trio.socket.SocketType | None) -> None:
self._ns = _ns self._ns = _ns
def is_connected(self) -> bool: def is_connected(self) -> bool:

View File

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
import typing
import zoneinfo import zoneinfo
_zoneinfo: typing.Optional[zoneinfo.ZoneInfo] = None _zoneinfo: zoneinfo.ZoneInfo | None = None
def get_local_timezone(): def get_local_timezone():