2
0

replace some typing.* constructs with native ones

This commit is contained in:
Stefan Bühler 2023-11-15 10:35:38 +01:00
parent 29d1a3226a
commit ced589f28a
16 changed files with 77 additions and 85 deletions

View File

@ -42,10 +42,10 @@ class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):
class MyQuartApp(quart_trio.QuartTrio):
my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
my_hub: typing.Optional[capport.comm.hub.Hub] = None
my_nc: capport.utils.ipneigh.NeighborController | None = None
my_hub: capport.comm.hub.Hub | None = None
my_config: capport.config.Config
custom_loader: typing.Optional[jinja2.FileSystemLoader] = None
custom_loader: jinja2.FileSystemLoader | None = None
def __init__(self, import_name: str, **kwargs) -> None:
self.my_config = capport.config.Config.load_default_once()

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import os.path
import re
import typing
import quart
@ -11,7 +10,7 @@ from .app import app
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
def parse_accept_language(value: str) -> typing.List[str]:
def parse_accept_language(value: str) -> list[str]:
value = value.strip()
if not value or value == '*':
return []
@ -72,7 +71,7 @@ def detect_language():
async def render_i18n_template(template, /, **kwargs) -> str:
langs: typing.List[str] = quart.g.langs
langs: list[str] = quart.g.langs
if not langs:
return await quart.render_template(template, **kwargs)
names = [

View File

@ -11,7 +11,7 @@ from werkzeug.http import parse_list_header
from .app import app
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]:
def _get_first_in_list(value_list: str | None, allowed: typing.Sequence[str] = ()) -> str | None:
if not value_list:
return None
values = parse_list_header(value_list)
@ -37,12 +37,12 @@ def local_proxy_fix(request: quart.Request):
return
request.remote_addr = client
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
port: typing.Optional[int] = None
port: int | None = None
if scheme:
port = 443 if scheme == 'https' else 80
request.scheme = scheme
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
port_s: typing.Optional[str]
port_s: str | None
if host:
request.host = host
if ':' in host and not host.endswith(']'):
@ -71,7 +71,7 @@ class LocalProxyFixRequestHandler:
def __init__(self, orig_handle_request):
self._orig_handle_request = orig_handle_request
async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]:
async def __call__(self, request: quart.Request) -> quart.Response | werkzeug.Response:
# need to patch request before url_adapter is built
local_proxy_fix(request)
return await self._orig_handle_request(request)

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import ipaddress
import logging
import typing
import quart
@ -34,15 +33,15 @@ def get_client_ip() -> cptypes.IPAddress:
async def get_client_mac_if_present(
address: typing.Optional[cptypes.IPAddress] = None,
) -> typing.Optional[cptypes.MacAddress]:
address: cptypes.IPAddress | None = None,
) -> cptypes.MacAddress | None:
assert app.my_nc # for mypy
if not address:
address = get_client_ip()
return await app.my_nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress:
async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")

View File

@ -147,7 +147,7 @@ class Connection:
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
try:
msg: typing.Optional[capport.comm.message.Message]
msg: capport.comm.message.Message | None
while True:
msg = None
# make sure we send something every PING_INTERVAL
@ -445,7 +445,7 @@ class Hub:
if conn:
await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None):
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: uuid.UUID | None = None):
async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items():
if peer_id == exclude:

View File

@ -14,13 +14,13 @@ class Message(google.protobuf.message.Message):
def __init__(
self,
*,
hello: typing.Optional[Hello]=None,
authentication_result: typing.Optional[AuthenticationResult]=None,
ping: typing.Optional[Ping]=None,
mac_states: typing.Optional[MacStates]=None,
hello: Hello | None=None,
authentication_result: AuthenticationResult | None=None,
ping: Ping | None=None,
mac_states: MacStates | None=None,
) -> None: ...
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
class Hello(google.protobuf.message.Message):
@ -66,12 +66,12 @@ class Ping(google.protobuf.message.Message):
class MacStates(google.protobuf.message.Message):
states: typing.List[MacState]
states: list[MacState]
def __init__(
self,
*,
states: typing.List[MacState]=[],
states: list[MacState]=[],
) -> None: ...
def to_message(self) -> Message: ...

View File

@ -3,23 +3,22 @@ from __future__ import annotations
import dataclasses
import logging
import os.path
import typing
import yaml
_cached_config: typing.Optional[Config] = None
_cached_config: Config | None = None
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class Config:
_source_filename: str
controllers: typing.List[str]
server_names: typing.List[str]
controllers: list[str]
server_names: list[str]
comm_secret: str
cookie_secret: str
venue_info_url: typing.Optional[str]
venue_info_url: str | None
session_timeout: int # in seconds
api_port: int
controller_port: int
@ -27,7 +26,7 @@ class Config:
debug: bool
@staticmethod
def load_default_once(filename: typing.Optional[str] = None) -> Config:
def load_default_once(filename: str | None = None) -> Config:
global _cached_config
if not _cached_config:
_cached_config = Config.load(filename)
@ -38,7 +37,7 @@ class Config:
return _cached_config
@staticmethod
def load(filename: typing.Optional[str] = None) -> Config:
def load(filename: str | None = None) -> Config:
if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name):

View File

@ -34,7 +34,7 @@ class ControlApp(capport.comm.hub.HubApplication):
def apply_db_entries(
self,
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
entries: typing.Iterable[tuple[cptypes.MacAddress, capport.database.MacEntry]],
) -> None:
# deploy changes to netfilter set
inserts = []

View File

@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
from .config import Config
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
@dataclasses.dataclass(frozen=True)
@ -53,7 +53,7 @@ class Timestamp:
return Timestamp(epoch=now)
@staticmethod
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
def from_protobuf(epoch: int) -> Timestamp | None:
if epoch:
return Timestamp(epoch=epoch)
return None
@ -62,7 +62,7 @@ class Timestamp:
@dataclasses.dataclass
class MacPublicState:
address: IPAddress
mac: typing.Optional[MacAddress]
mac: MacAddress | None
allowed_remaining: int
@staticmethod
@ -88,7 +88,7 @@ class MacPublicState:
return f'{hh}:{mm:02}:{ss:02}'
@property
def allowed_until(self) -> typing.Optional[datetime.datetime]:
def allowed_until(self) -> datetime.datetime | None:
zone = capport.utils.zoneinfo.get_local_timezone()
now = datetime.datetime.now(zone).replace(microsecond=0)
return now + datetime.timedelta(seconds=self.allowed_remaining)

View File

@ -28,11 +28,11 @@ class MacEntry:
last_change: cptypes.Timestamp
# only if allowed is true and allow_until is set the device can communicate with the internet
# allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp]
allow_until: cptypes.Timestamp | None
allowed: bool
@staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
def parse_state(msg: capport.comm.message.MacState) -> tuple[cptypes.MacAddress, MacEntry]:
if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address)
@ -90,7 +90,7 @@ class MacEntry:
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
def allowed_remaining(self, now: cptypes.Timestamp | None = None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
@ -98,15 +98,15 @@ class MacEntry:
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
def outdated(self, now: cptypes.Timestamp | None = None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = []
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> list[capport.comm.message.MacStates]:
result: list[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
state = entry.to_state(addr)
@ -121,7 +121,7 @@ def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.Li
def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
) -> list[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
@ -143,10 +143,9 @@ class Database:
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename
self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]] = None
self._send_changes: trio.MemorySendChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
] | None = None
@contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
@ -163,7 +162,7 @@ class Database:
def _drop_outdated(self) -> None:
done = False
while not done:
depr: typing.Set[cptypes.MacAddress] = set()
depr: set[cptypes.MacAddress] = set()
now = cptypes.Timestamp.now()
done = True
for mac, entry in self._macs.items():
@ -197,11 +196,11 @@ class Database:
await self._send_changes.send(states)
# for initial handling of all data
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
def entries(self) -> list[tuple[cptypes.MacAddress, MacEntry]]:
return list(self._macs.items())
# for initial sync with new peer
def serialize(self) -> typing.List[capport.comm.message.Message]:
def serialize(self) -> list[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
@ -211,14 +210,12 @@ class Database:
}
async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx: trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
rx: trio.MemoryReceiveChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
]
tx: trio.MemorySendChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
]
tx, rx = trio.open_memory_channel(64)
self._send_changes = tx
@ -226,7 +223,7 @@ class Database:
filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}'
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
async def resync(all_states: list[capport.comm.message.MacStates]):
try:
async with await trio.open_file(tmp_filename, 'xb') as tf:
for states in all_states:
@ -309,22 +306,22 @@ class PendingUpdates:
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database
self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
self._serialized: typing.List[capport.comm.message.Message] = []
self._serialized_states: list[capport.comm.message.MacStates] = []
self._serialized: list[capport.comm.message.Message] = []
def __bool__(self) -> bool:
return bool(self._changes)
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
def changes(self) -> typing.Iterable[tuple[cptypes.MacAddress, MacEntry]]:
return self._changes.items()
@property
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
def serialized_states(self) -> list[capport.comm.message.MacStates]:
assert self._closed
return self._serialized_states
@property
def serialized(self) -> typing.List[capport.comm.message.Message]:
def serialized(self) -> list[capport.comm.message.Message]:
assert self._closed
return self._serialized

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import ipaddress
import sys
import typing
import trio
@ -12,7 +11,7 @@ import capport.utils.nft_set
from . import cptypes
def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = None, help: typing.Optional[str] = None):
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
# no labels in our names for now, always print help and type
if help:
print(f'# HELP {name} {help}')
@ -25,11 +24,11 @@ def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = No
async def amain(client_ifname: str):
ns = capport.utils.nft_set.NftSet()
captive_allowed_entries: typing.Set[cptypes.MacAddress] = {
captive_allowed_entries: set[cptypes.MacAddress] = {
entry['mac']
for entry in ns.list()
}
seen_allowed_entries: typing.Set[cptypes.MacAddress] = set()
seen_allowed_entries: set[cptypes.MacAddress] = set()
total_ipv4 = 0
total_ipv6 = 0
unique_clients = set()

View File

@ -30,7 +30,7 @@ class NeighborController:
*,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
) -> pyroute2.iproute.linux.ndmsg.ndmsg | None:
if not index:
route = await self.get_route(address)
if route is None:
@ -49,7 +49,7 @@ class NeighborController:
*,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[cptypes.MacAddress]:
) -> cptypes.MacAddress | None:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
@ -61,7 +61,7 @@ class NeighborController:
async def get_route(
self,
address: cptypes.IPAddress,
) -> typing.Optional[pyroute2.iproute.linux.rtmsg]:
) -> pyroute2.iproute.linux.rtmsg | None:
try:
return self.ip.route('get', dst=str(address))[0]
except pyroute2.netlink.exceptions.NetlinkError as e:
@ -72,7 +72,7 @@ class NeighborController:
async def dump_neighbors(
self,
interface: str,
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)

View File

@ -12,7 +12,7 @@ from .nft_socket import NFTSocket
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
def _from_msec(msecs: int | None) -> float | None:
# to seconds
if msecs is None:
return None
@ -27,7 +27,7 @@ class NftSet:
@staticmethod
def _set_elem(
mac: cptypes.MacAddress,
timeout: typing.Optional[typing.Union[int, float]] = None,
timeout: int | float | None = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
@ -40,7 +40,7 @@ class NftSet:
def _bulk_insert(
self,
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
) -> None:
ser_entries = [
self._set_elem(mac)
@ -85,13 +85,13 @@ class NftSet:
),
)
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
def bulk_insert(self, entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_insert(entries[:1024])
entries = entries[1024:]
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
def insert(self, mac: cptypes.MacAddress, timeout: int | float) -> None:
self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:

View File

@ -37,7 +37,7 @@ def _monkey_patch_pyroute2():
self['header'].update(header)
return res
def overwrite_methods(cls: typing.Type) -> None:
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
if cls.setvalue is _orig_setvalue:
cls.setvalue = _nlmsg_base__setvalue
for subcls in cls.__subclasses__():
@ -49,7 +49,7 @@ def _monkey_patch_pyroute2():
_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
@ -79,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
policy: dict[int, type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
@ -110,7 +110,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_class: type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields)
@ -178,7 +178,7 @@ class NFTTransaction:
self._msgs.append(msg)
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_class: type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict(

View File

@ -19,7 +19,7 @@ def _check_watchdog_pid() -> bool:
@contextlib.asynccontextmanager
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
target = os.environ.pop('NOTIFY_SOCKET', None)
ns: typing.Optional[trio.socket.SocketType] = None
ns: trio.socket.SocketType | None = None
watchdog_usec: int = 0
if target:
if target.startswith('@'):
@ -44,7 +44,7 @@ async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
class SdNotify:
def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None:
def __init__(self, *, _ns: trio.socket.SocketType | None) -> None:
self._ns = _ns
def is_connected(self) -> bool:

View File

@ -1,10 +1,9 @@
from __future__ import annotations
import typing
import zoneinfo
_zoneinfo: typing.Optional[zoneinfo.ZoneInfo] = None
_zoneinfo: zoneinfo.ZoneInfo | None = None
def get_local_timezone():