replace some typing.* constructs with native ones
This commit is contained in:
parent
29d1a3226a
commit
ced589f28a
@ -42,10 +42,10 @@ class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):
|
|||||||
|
|
||||||
|
|
||||||
class MyQuartApp(quart_trio.QuartTrio):
|
class MyQuartApp(quart_trio.QuartTrio):
|
||||||
my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
|
my_nc: capport.utils.ipneigh.NeighborController | None = None
|
||||||
my_hub: typing.Optional[capport.comm.hub.Hub] = None
|
my_hub: capport.comm.hub.Hub | None = None
|
||||||
my_config: capport.config.Config
|
my_config: capport.config.Config
|
||||||
custom_loader: typing.Optional[jinja2.FileSystemLoader] = None
|
custom_loader: jinja2.FileSystemLoader | None = None
|
||||||
|
|
||||||
def __init__(self, import_name: str, **kwargs) -> None:
|
def __init__(self, import_name: str, **kwargs) -> None:
|
||||||
self.my_config = capport.config.Config.load_default_once()
|
self.my_config = capport.config.Config.load_default_once()
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
import typing
|
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ from .app import app
|
|||||||
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
|
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
|
||||||
|
|
||||||
|
|
||||||
def parse_accept_language(value: str) -> typing.List[str]:
|
def parse_accept_language(value: str) -> list[str]:
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
if not value or value == '*':
|
if not value or value == '*':
|
||||||
return []
|
return []
|
||||||
@ -72,7 +71,7 @@ def detect_language():
|
|||||||
|
|
||||||
|
|
||||||
async def render_i18n_template(template, /, **kwargs) -> str:
|
async def render_i18n_template(template, /, **kwargs) -> str:
|
||||||
langs: typing.List[str] = quart.g.langs
|
langs: list[str] = quart.g.langs
|
||||||
if not langs:
|
if not langs:
|
||||||
return await quart.render_template(template, **kwargs)
|
return await quart.render_template(template, **kwargs)
|
||||||
names = [
|
names = [
|
||||||
|
@ -11,7 +11,7 @@ from werkzeug.http import parse_list_header
|
|||||||
from .app import app
|
from .app import app
|
||||||
|
|
||||||
|
|
||||||
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]:
|
def _get_first_in_list(value_list: str | None, allowed: typing.Sequence[str] = ()) -> str | None:
|
||||||
if not value_list:
|
if not value_list:
|
||||||
return None
|
return None
|
||||||
values = parse_list_header(value_list)
|
values = parse_list_header(value_list)
|
||||||
@ -37,12 +37,12 @@ def local_proxy_fix(request: quart.Request):
|
|||||||
return
|
return
|
||||||
request.remote_addr = client
|
request.remote_addr = client
|
||||||
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
|
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
|
||||||
port: typing.Optional[int] = None
|
port: int | None = None
|
||||||
if scheme:
|
if scheme:
|
||||||
port = 443 if scheme == 'https' else 80
|
port = 443 if scheme == 'https' else 80
|
||||||
request.scheme = scheme
|
request.scheme = scheme
|
||||||
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
|
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
|
||||||
port_s: typing.Optional[str]
|
port_s: str | None
|
||||||
if host:
|
if host:
|
||||||
request.host = host
|
request.host = host
|
||||||
if ':' in host and not host.endswith(']'):
|
if ':' in host and not host.endswith(']'):
|
||||||
@ -71,7 +71,7 @@ class LocalProxyFixRequestHandler:
|
|||||||
def __init__(self, orig_handle_request):
|
def __init__(self, orig_handle_request):
|
||||||
self._orig_handle_request = orig_handle_request
|
self._orig_handle_request = orig_handle_request
|
||||||
|
|
||||||
async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]:
|
async def __call__(self, request: quart.Request) -> quart.Response | werkzeug.Response:
|
||||||
# need to patch request before url_adapter is built
|
# need to patch request before url_adapter is built
|
||||||
local_proxy_fix(request)
|
local_proxy_fix(request)
|
||||||
return await self._orig_handle_request(request)
|
return await self._orig_handle_request(request)
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
|
|
||||||
@ -34,15 +33,15 @@ def get_client_ip() -> cptypes.IPAddress:
|
|||||||
|
|
||||||
|
|
||||||
async def get_client_mac_if_present(
|
async def get_client_mac_if_present(
|
||||||
address: typing.Optional[cptypes.IPAddress] = None,
|
address: cptypes.IPAddress | None = None,
|
||||||
) -> typing.Optional[cptypes.MacAddress]:
|
) -> cptypes.MacAddress | None:
|
||||||
assert app.my_nc # for mypy
|
assert app.my_nc # for mypy
|
||||||
if not address:
|
if not address:
|
||||||
address = get_client_ip()
|
address = get_client_ip()
|
||||||
return await app.my_nc.get_neighbor_mac(address)
|
return await app.my_nc.get_neighbor_mac(address)
|
||||||
|
|
||||||
|
|
||||||
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress:
|
async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.MacAddress:
|
||||||
mac = await get_client_mac_if_present(address)
|
mac = await get_client_mac_if_present(address)
|
||||||
if mac is None:
|
if mac is None:
|
||||||
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
||||||
|
@ -147,7 +147,7 @@ class Connection:
|
|||||||
|
|
||||||
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
|
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
|
||||||
try:
|
try:
|
||||||
msg: typing.Optional[capport.comm.message.Message]
|
msg: capport.comm.message.Message | None
|
||||||
while True:
|
while True:
|
||||||
msg = None
|
msg = None
|
||||||
# make sure we send something every PING_INTERVAL
|
# make sure we send something every PING_INTERVAL
|
||||||
@ -445,7 +445,7 @@ class Hub:
|
|||||||
if conn:
|
if conn:
|
||||||
await conn.send_msg(*msgs)
|
await conn.send_msg(*msgs)
|
||||||
|
|
||||||
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None):
|
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: uuid.UUID | None = None):
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for peer_id, conn in self._established.items():
|
for peer_id, conn in self._established.items():
|
||||||
if peer_id == exclude:
|
if peer_id == exclude:
|
||||||
|
@ -14,13 +14,13 @@ class Message(google.protobuf.message.Message):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
hello: typing.Optional[Hello]=None,
|
hello: Hello | None=None,
|
||||||
authentication_result: typing.Optional[AuthenticationResult]=None,
|
authentication_result: AuthenticationResult | None=None,
|
||||||
ping: typing.Optional[Ping]=None,
|
ping: Ping | None=None,
|
||||||
mac_states: typing.Optional[MacStates]=None,
|
mac_states: MacStates | None=None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
|
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
|
||||||
|
|
||||||
|
|
||||||
class Hello(google.protobuf.message.Message):
|
class Hello(google.protobuf.message.Message):
|
||||||
@ -66,12 +66,12 @@ class Ping(google.protobuf.message.Message):
|
|||||||
|
|
||||||
|
|
||||||
class MacStates(google.protobuf.message.Message):
|
class MacStates(google.protobuf.message.Message):
|
||||||
states: typing.List[MacState]
|
states: list[MacState]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
states: typing.List[MacState]=[],
|
states: list[MacState]=[],
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_message(self) -> Message: ...
|
def to_message(self) -> Message: ...
|
||||||
|
@ -3,23 +3,22 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os.path
|
import os.path
|
||||||
import typing
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
_cached_config: typing.Optional[Config] = None
|
_cached_config: Config | None = None
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Config:
|
class Config:
|
||||||
_source_filename: str
|
_source_filename: str
|
||||||
controllers: typing.List[str]
|
controllers: list[str]
|
||||||
server_names: typing.List[str]
|
server_names: list[str]
|
||||||
comm_secret: str
|
comm_secret: str
|
||||||
cookie_secret: str
|
cookie_secret: str
|
||||||
venue_info_url: typing.Optional[str]
|
venue_info_url: str | None
|
||||||
session_timeout: int # in seconds
|
session_timeout: int # in seconds
|
||||||
api_port: int
|
api_port: int
|
||||||
controller_port: int
|
controller_port: int
|
||||||
@ -27,7 +26,7 @@ class Config:
|
|||||||
debug: bool
|
debug: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_default_once(filename: typing.Optional[str] = None) -> Config:
|
def load_default_once(filename: str | None = None) -> Config:
|
||||||
global _cached_config
|
global _cached_config
|
||||||
if not _cached_config:
|
if not _cached_config:
|
||||||
_cached_config = Config.load(filename)
|
_cached_config = Config.load(filename)
|
||||||
@ -38,7 +37,7 @@ class Config:
|
|||||||
return _cached_config
|
return _cached_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(filename: typing.Optional[str] = None) -> Config:
|
def load(filename: str | None = None) -> Config:
|
||||||
if filename is None:
|
if filename is None:
|
||||||
for name in ('capport.yaml', '/etc/capport.yaml'):
|
for name in ('capport.yaml', '/etc/capport.yaml'):
|
||||||
if os.path.exists(name):
|
if os.path.exists(name):
|
||||||
|
@ -34,7 +34,7 @@ class ControlApp(capport.comm.hub.HubApplication):
|
|||||||
|
|
||||||
def apply_db_entries(
|
def apply_db_entries(
|
||||||
self,
|
self,
|
||||||
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
|
entries: typing.Iterable[tuple[cptypes.MacAddress, capport.database.MacEntry]],
|
||||||
) -> None:
|
) -> None:
|
||||||
# deploy changes to netfilter set
|
# deploy changes to netfilter set
|
||||||
inserts = []
|
inserts = []
|
||||||
|
@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
|
|||||||
from .config import Config
|
from .config import Config
|
||||||
|
|
||||||
|
|
||||||
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
@ -53,7 +53,7 @@ class Timestamp:
|
|||||||
return Timestamp(epoch=now)
|
return Timestamp(epoch=now)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
|
def from_protobuf(epoch: int) -> Timestamp | None:
|
||||||
if epoch:
|
if epoch:
|
||||||
return Timestamp(epoch=epoch)
|
return Timestamp(epoch=epoch)
|
||||||
return None
|
return None
|
||||||
@ -62,7 +62,7 @@ class Timestamp:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MacPublicState:
|
class MacPublicState:
|
||||||
address: IPAddress
|
address: IPAddress
|
||||||
mac: typing.Optional[MacAddress]
|
mac: MacAddress | None
|
||||||
allowed_remaining: int
|
allowed_remaining: int
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -88,7 +88,7 @@ class MacPublicState:
|
|||||||
return f'{hh}:{mm:02}:{ss:02}'
|
return f'{hh}:{mm:02}:{ss:02}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allowed_until(self) -> typing.Optional[datetime.datetime]:
|
def allowed_until(self) -> datetime.datetime | None:
|
||||||
zone = capport.utils.zoneinfo.get_local_timezone()
|
zone = capport.utils.zoneinfo.get_local_timezone()
|
||||||
now = datetime.datetime.now(zone).replace(microsecond=0)
|
now = datetime.datetime.now(zone).replace(microsecond=0)
|
||||||
return now + datetime.timedelta(seconds=self.allowed_remaining)
|
return now + datetime.timedelta(seconds=self.allowed_remaining)
|
||||||
|
@ -28,11 +28,11 @@ class MacEntry:
|
|||||||
last_change: cptypes.Timestamp
|
last_change: cptypes.Timestamp
|
||||||
# only if allowed is true and allow_until is set the device can communicate with the internet
|
# only if allowed is true and allow_until is set the device can communicate with the internet
|
||||||
# allow_until must not go backwards (and not get unset)
|
# allow_until must not go backwards (and not get unset)
|
||||||
allow_until: typing.Optional[cptypes.Timestamp]
|
allow_until: cptypes.Timestamp | None
|
||||||
allowed: bool
|
allowed: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
|
def parse_state(msg: capport.comm.message.MacState) -> tuple[cptypes.MacAddress, MacEntry]:
|
||||||
if len(msg.mac_address) < 6:
|
if len(msg.mac_address) < 6:
|
||||||
raise Exception("Invalid MacState: mac_address too short")
|
raise Exception("Invalid MacState: mac_address too short")
|
||||||
addr = cptypes.MacAddress(raw=msg.mac_address)
|
addr = cptypes.MacAddress(raw=msg.mac_address)
|
||||||
@ -90,7 +90,7 @@ class MacEntry:
|
|||||||
return cptypes.Timestamp(epoch=elc)
|
return cptypes.Timestamp(epoch=elc)
|
||||||
|
|
||||||
# returns 0 if not allowed
|
# returns 0 if not allowed
|
||||||
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
|
def allowed_remaining(self, now: cptypes.Timestamp | None = None) -> int:
|
||||||
if not self.allowed or not self.allow_until:
|
if not self.allowed or not self.allow_until:
|
||||||
return 0
|
return 0
|
||||||
if not now:
|
if not now:
|
||||||
@ -98,15 +98,15 @@ class MacEntry:
|
|||||||
assert self.allow_until
|
assert self.allow_until
|
||||||
return max(self.allow_until.epoch - now.epoch, 0)
|
return max(self.allow_until.epoch - now.epoch, 0)
|
||||||
|
|
||||||
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
|
def outdated(self, now: cptypes.Timestamp | None = None) -> bool:
|
||||||
if not now:
|
if not now:
|
||||||
now = cptypes.Timestamp.now()
|
now = cptypes.Timestamp.now()
|
||||||
return now.epoch > self.timeout().epoch
|
return now.epoch > self.timeout().epoch
|
||||||
|
|
||||||
|
|
||||||
# might use this to serialize into file - don't need Message variant there
|
# might use this to serialize into file - don't need Message variant there
|
||||||
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
|
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> list[capport.comm.message.MacStates]:
|
||||||
result: typing.List[capport.comm.message.MacStates] = []
|
result: list[capport.comm.message.MacStates] = []
|
||||||
current = capport.comm.message.MacStates()
|
current = capport.comm.message.MacStates()
|
||||||
for addr, entry in macs.items():
|
for addr, entry in macs.items():
|
||||||
state = entry.to_state(addr)
|
state = entry.to_state(addr)
|
||||||
@ -121,7 +121,7 @@ def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.Li
|
|||||||
|
|
||||||
def _serialize_mac_states_as_messages(
|
def _serialize_mac_states_as_messages(
|
||||||
macs: dict[cptypes.MacAddress, MacEntry],
|
macs: dict[cptypes.MacAddress, MacEntry],
|
||||||
) -> typing.List[capport.comm.message.Message]:
|
) -> list[capport.comm.message.Message]:
|
||||||
return [s.to_message() for s in _serialize_mac_states(macs)]
|
return [s.to_message() for s in _serialize_mac_states(macs)]
|
||||||
|
|
||||||
|
|
||||||
@ -143,10 +143,9 @@ class Database:
|
|||||||
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
|
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
|
||||||
self._state_filename = state_filename
|
self._state_filename = state_filename
|
||||||
self._changed_since_last_cleanup = False
|
self._changed_since_last_cleanup = False
|
||||||
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
|
self._send_changes: trio.MemorySendChannel[
|
||||||
capport.comm.message.MacStates,
|
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
||||||
typing.List[capport.comm.message.MacStates],
|
] | None = None
|
||||||
]]] = None
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
|
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
|
||||||
@ -163,7 +162,7 @@ class Database:
|
|||||||
def _drop_outdated(self) -> None:
|
def _drop_outdated(self) -> None:
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
depr: typing.Set[cptypes.MacAddress] = set()
|
depr: set[cptypes.MacAddress] = set()
|
||||||
now = cptypes.Timestamp.now()
|
now = cptypes.Timestamp.now()
|
||||||
done = True
|
done = True
|
||||||
for mac, entry in self._macs.items():
|
for mac, entry in self._macs.items():
|
||||||
@ -197,11 +196,11 @@ class Database:
|
|||||||
await self._send_changes.send(states)
|
await self._send_changes.send(states)
|
||||||
|
|
||||||
# for initial handling of all data
|
# for initial handling of all data
|
||||||
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
|
def entries(self) -> list[tuple[cptypes.MacAddress, MacEntry]]:
|
||||||
return list(self._macs.items())
|
return list(self._macs.items())
|
||||||
|
|
||||||
# for initial sync with new peer
|
# for initial sync with new peer
|
||||||
def serialize(self) -> typing.List[capport.comm.message.Message]:
|
def serialize(self) -> list[capport.comm.message.Message]:
|
||||||
return _serialize_mac_states_as_messages(self._macs)
|
return _serialize_mac_states_as_messages(self._macs)
|
||||||
|
|
||||||
def as_json(self) -> dict:
|
def as_json(self) -> dict:
|
||||||
@ -211,14 +210,12 @@ class Database:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def _run_statefile(self) -> None:
|
async def _run_statefile(self) -> None:
|
||||||
rx: trio.MemoryReceiveChannel[typing.Union[
|
rx: trio.MemoryReceiveChannel[
|
||||||
capport.comm.message.MacStates,
|
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
||||||
typing.List[capport.comm.message.MacStates],
|
]
|
||||||
]]
|
tx: trio.MemorySendChannel[
|
||||||
tx: trio.MemorySendChannel[typing.Union[
|
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
||||||
capport.comm.message.MacStates,
|
]
|
||||||
typing.List[capport.comm.message.MacStates],
|
|
||||||
]]
|
|
||||||
tx, rx = trio.open_memory_channel(64)
|
tx, rx = trio.open_memory_channel(64)
|
||||||
self._send_changes = tx
|
self._send_changes = tx
|
||||||
|
|
||||||
@ -226,7 +223,7 @@ class Database:
|
|||||||
filename: str = self._state_filename
|
filename: str = self._state_filename
|
||||||
tmp_filename = f'{filename}.new-{os.getpid()}'
|
tmp_filename = f'{filename}.new-{os.getpid()}'
|
||||||
|
|
||||||
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
|
async def resync(all_states: list[capport.comm.message.MacStates]):
|
||||||
try:
|
try:
|
||||||
async with await trio.open_file(tmp_filename, 'xb') as tf:
|
async with await trio.open_file(tmp_filename, 'xb') as tf:
|
||||||
for states in all_states:
|
for states in all_states:
|
||||||
@ -309,22 +306,22 @@ class PendingUpdates:
|
|||||||
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
|
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
|
||||||
self._database = database
|
self._database = database
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
|
self._serialized_states: list[capport.comm.message.MacStates] = []
|
||||||
self._serialized: typing.List[capport.comm.message.Message] = []
|
self._serialized: list[capport.comm.message.Message] = []
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return bool(self._changes)
|
return bool(self._changes)
|
||||||
|
|
||||||
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
|
def changes(self) -> typing.Iterable[tuple[cptypes.MacAddress, MacEntry]]:
|
||||||
return self._changes.items()
|
return self._changes.items()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
|
def serialized_states(self) -> list[capport.comm.message.MacStates]:
|
||||||
assert self._closed
|
assert self._closed
|
||||||
return self._serialized_states
|
return self._serialized_states
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serialized(self) -> typing.List[capport.comm.message.Message]:
|
def serialized(self) -> list[capport.comm.message.Message]:
|
||||||
assert self._closed
|
assert self._closed
|
||||||
return self._serialized
|
return self._serialized
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import sys
|
import sys
|
||||||
import typing
|
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
@ -12,7 +11,7 @@ import capport.utils.nft_set
|
|||||||
from . import cptypes
|
from . import cptypes
|
||||||
|
|
||||||
|
|
||||||
def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = None, help: typing.Optional[str] = None):
|
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
|
||||||
# no labels in our names for now, always print help and type
|
# no labels in our names for now, always print help and type
|
||||||
if help:
|
if help:
|
||||||
print(f'# HELP {name} {help}')
|
print(f'# HELP {name} {help}')
|
||||||
@ -25,11 +24,11 @@ def print_metric(name: str, mtype: str, value, *, now: typing.Optional[int] = No
|
|||||||
|
|
||||||
async def amain(client_ifname: str):
|
async def amain(client_ifname: str):
|
||||||
ns = capport.utils.nft_set.NftSet()
|
ns = capport.utils.nft_set.NftSet()
|
||||||
captive_allowed_entries: typing.Set[cptypes.MacAddress] = {
|
captive_allowed_entries: set[cptypes.MacAddress] = {
|
||||||
entry['mac']
|
entry['mac']
|
||||||
for entry in ns.list()
|
for entry in ns.list()
|
||||||
}
|
}
|
||||||
seen_allowed_entries: typing.Set[cptypes.MacAddress] = set()
|
seen_allowed_entries: set[cptypes.MacAddress] = set()
|
||||||
total_ipv4 = 0
|
total_ipv4 = 0
|
||||||
total_ipv6 = 0
|
total_ipv6 = 0
|
||||||
unique_clients = set()
|
unique_clients = set()
|
||||||
|
@ -30,7 +30,7 @@ class NeighborController:
|
|||||||
*,
|
*,
|
||||||
index: int = 0, # interface index
|
index: int = 0, # interface index
|
||||||
flags: int = 0,
|
flags: int = 0,
|
||||||
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
|
) -> pyroute2.iproute.linux.ndmsg.ndmsg | None:
|
||||||
if not index:
|
if not index:
|
||||||
route = await self.get_route(address)
|
route = await self.get_route(address)
|
||||||
if route is None:
|
if route is None:
|
||||||
@ -49,7 +49,7 @@ class NeighborController:
|
|||||||
*,
|
*,
|
||||||
index: int = 0, # interface index
|
index: int = 0, # interface index
|
||||||
flags: int = 0,
|
flags: int = 0,
|
||||||
) -> typing.Optional[cptypes.MacAddress]:
|
) -> cptypes.MacAddress | None:
|
||||||
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||||
if neigh is None:
|
if neigh is None:
|
||||||
return None
|
return None
|
||||||
@ -61,7 +61,7 @@ class NeighborController:
|
|||||||
async def get_route(
|
async def get_route(
|
||||||
self,
|
self,
|
||||||
address: cptypes.IPAddress,
|
address: cptypes.IPAddress,
|
||||||
) -> typing.Optional[pyroute2.iproute.linux.rtmsg]:
|
) -> pyroute2.iproute.linux.rtmsg | None:
|
||||||
try:
|
try:
|
||||||
return self.ip.route('get', dst=str(address))[0]
|
return self.ip.route('get', dst=str(address))[0]
|
||||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||||
@ -72,7 +72,7 @@ class NeighborController:
|
|||||||
async def dump_neighbors(
|
async def dump_neighbors(
|
||||||
self,
|
self,
|
||||||
interface: str,
|
interface: str,
|
||||||
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
||||||
ifindex = socket.if_nametoindex(interface)
|
ifindex = socket.if_nametoindex(interface)
|
||||||
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
|
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
|
||||||
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
||||||
|
@ -12,7 +12,7 @@ from .nft_socket import NFTSocket
|
|||||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
|
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
|
||||||
|
|
||||||
|
|
||||||
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
|
def _from_msec(msecs: int | None) -> float | None:
|
||||||
# to seconds
|
# to seconds
|
||||||
if msecs is None:
|
if msecs is None:
|
||||||
return None
|
return None
|
||||||
@ -27,7 +27,7 @@ class NftSet:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_elem(
|
def _set_elem(
|
||||||
mac: cptypes.MacAddress,
|
mac: cptypes.MacAddress,
|
||||||
timeout: typing.Optional[typing.Union[int, float]] = None,
|
timeout: int | float | None = None,
|
||||||
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||||
attrs: dict[str, typing.Any] = {
|
attrs: dict[str, typing.Any] = {
|
||||||
'NFTA_SET_ELEM_KEY': dict(
|
'NFTA_SET_ELEM_KEY': dict(
|
||||||
@ -40,7 +40,7 @@ class NftSet:
|
|||||||
|
|
||||||
def _bulk_insert(
|
def _bulk_insert(
|
||||||
self,
|
self,
|
||||||
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
|
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
|
||||||
) -> None:
|
) -> None:
|
||||||
ser_entries = [
|
ser_entries = [
|
||||||
self._set_elem(mac)
|
self._set_elem(mac)
|
||||||
@ -85,13 +85,13 @@ class NftSet:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
def bulk_insert(self, entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]]) -> None:
|
||||||
# limit chunk size
|
# limit chunk size
|
||||||
while len(entries) > 0:
|
while len(entries) > 0:
|
||||||
self._bulk_insert(entries[:1024])
|
self._bulk_insert(entries[:1024])
|
||||||
entries = entries[1024:]
|
entries = entries[1024:]
|
||||||
|
|
||||||
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
|
def insert(self, mac: cptypes.MacAddress, timeout: int | float) -> None:
|
||||||
self.bulk_insert([(mac, timeout)])
|
self.bulk_insert([(mac, timeout)])
|
||||||
|
|
||||||
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||||
|
@ -37,7 +37,7 @@ def _monkey_patch_pyroute2():
|
|||||||
self['header'].update(header)
|
self['header'].update(header)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def overwrite_methods(cls: typing.Type) -> None:
|
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
|
||||||
if cls.setvalue is _orig_setvalue:
|
if cls.setvalue is _orig_setvalue:
|
||||||
cls.setvalue = _nlmsg_base__setvalue
|
cls.setvalue = _nlmsg_base__setvalue
|
||||||
for subcls in cls.__subclasses__():
|
for subcls in cls.__subclasses__():
|
||||||
@ -49,7 +49,7 @@ def _monkey_patch_pyroute2():
|
|||||||
_monkey_patch_pyroute2()
|
_monkey_patch_pyroute2()
|
||||||
|
|
||||||
|
|
||||||
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
||||||
msg = msg_class()
|
msg = msg_class()
|
||||||
for key, value in header.items():
|
for key, value in header.items():
|
||||||
msg['header'][key] = value
|
msg['header'][key] = value
|
||||||
@ -79,7 +79,7 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict
|
|||||||
|
|
||||||
|
|
||||||
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
||||||
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
|
policy: dict[int, type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
||||||
@ -110,7 +110,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
|||||||
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
|
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
|
||||||
|
|
||||||
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
|
msg_class: type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
|
||||||
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
|
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
|
||||||
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
|
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
|
||||||
msg = _build(msg_class, attrs=attrs, **fields)
|
msg = _build(msg_class, attrs=attrs, **fields)
|
||||||
@ -178,7 +178,7 @@ class NFTTransaction:
|
|||||||
self._msgs.append(msg)
|
self._msgs.append(msg)
|
||||||
|
|
||||||
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
msg_class: type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
||||||
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
|
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
|
||||||
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
|
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
|
||||||
header = dict(
|
header = dict(
|
||||||
|
@ -19,7 +19,7 @@ def _check_watchdog_pid() -> bool:
|
|||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
||||||
target = os.environ.pop('NOTIFY_SOCKET', None)
|
target = os.environ.pop('NOTIFY_SOCKET', None)
|
||||||
ns: typing.Optional[trio.socket.SocketType] = None
|
ns: trio.socket.SocketType | None = None
|
||||||
watchdog_usec: int = 0
|
watchdog_usec: int = 0
|
||||||
if target:
|
if target:
|
||||||
if target.startswith('@'):
|
if target.startswith('@'):
|
||||||
@ -44,7 +44,7 @@ async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
|||||||
|
|
||||||
|
|
||||||
class SdNotify:
|
class SdNotify:
|
||||||
def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None:
|
def __init__(self, *, _ns: trio.socket.SocketType | None) -> None:
|
||||||
self._ns = _ns
|
self._ns = _ns
|
||||||
|
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
|
||||||
import zoneinfo
|
import zoneinfo
|
||||||
|
|
||||||
|
|
||||||
_zoneinfo: typing.Optional[zoneinfo.ZoneInfo] = None
|
_zoneinfo: zoneinfo.ZoneInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_local_timezone():
|
def get_local_timezone():
|
||||||
|
Loading…
Reference in New Issue
Block a user