From 0f45f892118b3a4c89b6058ee4ad3ecedfcf6ada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Wed, 15 Nov 2023 10:02:28 +0100 Subject: [PATCH] format with black (mostly quotes) --- pyproject.toml | 7 +++- src/capport/api/app.py | 10 ++--- src/capport/api/app_cls.py | 10 ++--- src/capport/api/hypercorn_run.py | 8 ++-- src/capport/api/lang.py | 47 +++++++++++----------- src/capport/api/proxy_fix.py | 16 ++++---- src/capport/api/setup.py | 4 +- src/capport/api/views.py | 60 ++++++++++++++-------------- src/capport/comm/hub.py | 27 +++++++------ src/capport/comm/message.py | 7 ++-- src/capport/comm/message.pyi | 42 +++++++------------- src/capport/config.py | 22 +++++------ src/capport/control/run.py | 6 +-- src/capport/cptypes.py | 26 ++++++------- src/capport/database.py | 33 +++++++--------- src/capport/stats.py | 55 +++++++++++++------------- src/capport/utils/cli.py | 6 +-- src/capport/utils/ipneigh.py | 18 ++++----- src/capport/utils/nft_set.py | 67 ++++++++++++++------------------ src/capport/utils/nft_socket.py | 24 ++++++------ src/capport/utils/sd_notify.py | 18 ++++----- src/capport/utils/zoneinfo.py | 4 +- 22 files changed, 245 insertions(+), 272 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd74a36..b53bc37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,14 @@ requires = [ build-backend = "setuptools.build_meta" [tool.mypy] -python_version = "3.9" +python_version = "3.11" # warn_return_any = true warn_unused_configs = true exclude = [ '_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary) ] + +[tool.black] +line-length = 120 +target-version = ['py311'] +exclude = '_pb2.py' diff --git a/src/capport/api/app.py b/src/capport/api/app.py index 9b78a1b..2e28287 100644 --- a/src/capport/api/app.py +++ b/src/capport/api/app.py @@ -4,8 +4,8 @@ from .app_cls import MyQuartApp app = MyQuartApp(__name__) -__import__('capport.api.setup') -__import__('capport.api.proxy_fix') -__import__('capport.api.lang') -__import__('capport.api.template_filters') -__import__('capport.api.views') +__import__("capport.api.setup") +__import__("capport.api.proxy_fix") +__import__("capport.api.lang") +__import__("capport.api.template_filters") +__import__("capport.api.views") diff --git a/src/capport/api/app_cls.py b/src/capport/api/app_cls.py index be89589..cedc52b 100644 --- a/src/capport/api/app_cls.py +++ b/src/capport/api/app_cls.py @@ -49,16 +49,16 @@ class MyQuartApp(quart_trio.QuartTrio): def __init__(self, import_name: str, **kwargs) -> None: self.my_config = capport.config.Config.load_default_once() - kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates')) - cust_templ = os.path.join('custom', 'templates') + kwargs.setdefault("template_folder", os.path.join(os.path.dirname(__file__), "templates")) + cust_templ = os.path.join("custom", "templates") if os.path.exists(cust_templ): self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ)) - cust_static = os.path.abspath(os.path.join('custom', 'static')) + cust_static = os.path.abspath(os.path.join("custom", "static")) if os.path.exists(cust_static): static_folder = cust_static else: - static_folder = os.path.join(os.path.dirname(__file__), 'static') - kwargs.setdefault('static_folder', static_folder) + static_folder = os.path.join(os.path.dirname(__file__), "static") + kwargs.setdefault("static_folder", static_folder) super().__init__(import_name, **kwargs) self.debug = self.my_config.debug self.secret_key = self.my_config.cookie_secret diff --git a/src/capport/api/hypercorn_run.py b/src/capport/api/hypercorn_run.py index c9da53c..aaa4f18 100644 --- a/src/capport/api/hypercorn_run.py +++ b/src/capport/api/hypercorn_run.py @@ -12,7 +12,7 @@ import capport.config def run(config: hypercorn.config.Config) -> None: sockets = config.create_sockets() - assert config.worker_class == 'trio' + assert config.worker_class == "trio" hypercorn.trio.run.trio_worker(config=config, sockets=sockets) @@ -28,7 +28,7 @@ class CliArguments: def __init__(self): parser = argparse.ArgumentParser() - parser.add_argument('--config', '-c') + parser.add_argument("--config", "-c") args = parser.parse_args() self.config = args.config @@ -39,8 +39,8 @@ def main() -> None: _config = capport.config.Config.load_default_once(filename=args.config) hypercorn_config = hypercorn.config.Config() - hypercorn_config.application_path = 'capport.api.app' - hypercorn_config.worker_class = 'trio' + hypercorn_config.application_path = "capport.api.app" + hypercorn_config.worker_class = "trio" hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"] if _config.server_names: diff --git a/src/capport/api/lang.py b/src/capport/api/lang.py index 86e7689..eca8106 100644 --- a/src/capport/api/lang.py +++ b/src/capport/api/lang.py @@ -7,22 +7,23 @@ import quart from .app import app -_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$') +_VALID_LANGUAGE_NAMES = re.compile(r"^[-a-z0-9_]+$") def parse_accept_language(value: str) -> list[str]: value = value.strip() - if not value or value == '*': + if not value or value == "*": return [] tuples = [] - for entry in value.split(','): - attrs = entry.split(';') + for entry in value.split(","): + attrs = entry.split(";") name = attrs.pop(0).strip().lower() q = 1.0 for attr in attrs: - if not '=' in attr: continue - key, value = attr.split('=', maxsplit=1) - if key.strip().lower() == 'q': + if not "=" in attr: + continue + key, value = attr.split("=", maxsplit=1) + if key.strip().lower() == "q": try: q = float(value.strip()) except ValueError: @@ -32,14 +33,17 @@ def parse_accept_language(value: str) -> list[str]: tuples.sort() have = set() result = [] - for (_q, name) in tuples: - if name in have: continue - if name == '*': break + for _q, name in tuples: + if name in have: + continue + if name == "*": + break have.add(name) if _VALID_LANGUAGE_NAMES.match(name): result.append(name) - short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0] - if not short_name or short_name in have: continue + short_name = name.split("-", maxsplit=1)[0].split("_", maxsplit=1)[0] + if not short_name or short_name in have: + continue have.add(short_name) result.append(short_name) return result @@ -50,23 +54,23 @@ def detect_language(): g = quart.g r = quart.request s = quart.session - if 'setlang' in r.args: - lang = r.args.get('setlang').strip().lower() + if "setlang" in r.args: + lang = r.args.get("setlang").strip().lower() if lang and _VALID_LANGUAGE_NAMES.match(lang): - if s.get('lang') != lang: - s['lang'] = lang + if s.get("lang") != lang: + s["lang"] = lang g.langs = [lang] return else: # reset language - s.pop('lang', None) - lang = s.get('lang') + s.pop("lang", None) + lang = s.get("lang") if lang: lang = lang.strip().lower() if lang and _VALID_LANGUAGE_NAMES.match(lang): g.langs = [lang] return - acc_lang = ','.join(r.headers.get_all('Accept-Language')) + acc_lang = ",".join(r.headers.get_all("Accept-Language")) g.langs = parse_accept_language(acc_lang) @@ -74,9 +78,6 @@ async def render_i18n_template(template, /, **kwargs) -> str: langs: list[str] = quart.g.langs if not langs: return await quart.render_template(template, **kwargs) - names = [ - os.path.join('i18n', lang, template) - for lang in langs - ] + names = [os.path.join("i18n", lang, template) for lang in langs] names.append(template) return await quart.render_template(names, **kwargs) diff --git a/src/capport/api/proxy_fix.py b/src/capport/api/proxy_fix.py index 303b2e6..4f14a05 100644 --- a/src/capport/api/proxy_fix.py +++ b/src/capport/api/proxy_fix.py @@ -31,28 +31,28 @@ def local_proxy_fix(request: quart.Request): return if not addr.is_loopback: return - client = _get_first_in_list(request.headers.get('X-Forwarded-For')) + client = _get_first_in_list(request.headers.get("X-Forwarded-For")) if not client: # assume this is always set behind reverse proxies supporting any of the headers return request.remote_addr = client - scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https')) + scheme = _get_first_in_list(request.headers.get("X-Forwarded-Proto"), ("http", "https")) port: int | None = None if scheme: - port = 443 if scheme == 'https' else 80 + port = 443 if scheme == "https" else 80 request.scheme = scheme - host = _get_first_in_list(request.headers.get('X-Forwarded-Host')) + host = _get_first_in_list(request.headers.get("X-Forwarded-Host")) port_s: str | None if host: request.host = host - if ':' in host and not host.endswith(']'): + if ":" in host and not host.endswith("]"): try: - _, port_s = host.rsplit(':', maxsplit=1) + _, port_s = host.rsplit(":", maxsplit=1) port = int(port_s) except ValueError: # ignore invalid port in host header pass - port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port')) + port_s = _get_first_in_list(request.headers.get("X-Forwarded-Port")) if port_s: try: port = int(port_s) @@ -62,7 +62,7 @@ def local_proxy_fix(request: quart.Request): if port: if request.server and len(request.server) == 2: request.server = (request.server[0], port) - root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix')) + root_path = _get_first_in_list(request.headers.get("X-Forwarded-Prefix")) if root_path: request.root_path = root_path diff --git a/src/capport/api/setup.py b/src/capport/api/setup.py index 898984f..76ab1bd 100644 --- a/src/capport/api/setup.py +++ b/src/capport/api/setup.py @@ -47,10 +47,10 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: async def _setup(*, task_status=trio.TASK_STATUS_IGNORED): async with open_sdnotify() as sn: - await sn.send('STATUS=Starting hub') + await sn.send("STATUS=Starting hub") async with trio.open_nursery() as nursery: await nursery.start(_run_hub) - await sn.send('READY=1', 'STATUS=Ready for client requests') + await sn.send("READY=1", "STATUS=Ready for client requests") task_status.started() # continue running hub and systemd watchdog handler diff --git a/src/capport/api/views.py b/src/capport/api/views.py index f1d10f5..706e0bd 100644 --- a/src/capport/api/views.py +++ b/src/capport/api/views.py @@ -23,12 +23,12 @@ _logger = logging.getLogger(__name__) def get_client_ip() -> cptypes.IPAddress: remote_addr = quart.request.remote_addr if not remote_addr: - quart.abort(500, 'Missing client address') + quart.abort(500, "Missing client address") try: addr = ipaddress.ip_address(remote_addr) except ValueError as e: - _logger.warning(f'Invalid client address {remote_addr!r}: {e}') - quart.abort(500, 'Invalid client address') + _logger.warning(f"Invalid client address {remote_addr!r}: {e}") + quart.abort(500, "Invalid client address") return addr @@ -45,7 +45,7 @@ async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.Ma mac = await get_client_mac_if_present(address) if mac is None: _logger.warning(f"Couldn't find MAC addresss for {address}") - quart.abort(404, 'Unknown client') + quart.abort(404, "Unknown client") return mac @@ -58,7 +58,7 @@ async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> Non quart.abort(500, str(e)) if pu: - _logger.debug(f'User {mac} (with IP {address}) logged in') + _logger.debug(f"User {mac} (with IP {address}) logged in") for msg in pu.serialized: await app.my_hub.broadcast(msg) @@ -71,7 +71,7 @@ async def user_logout(mac: cptypes.MacAddress) -> None: except capport.database.NotReadyYet as e: quart.abort(500, str(e)) if pu: - _logger.debug(f'User {mac} logged out') + _logger.debug(f"User {mac} logged out") for msg in pu.serialized: await app.my_hub.broadcast(msg) @@ -92,73 +92,73 @@ async def user_lookup() -> cptypes.MacPublicState: def check_self_origin(): - origin = quart.request.headers.get('Origin', None) + origin = quart.request.headers.get("Origin", None) if origin is None: # not a request by a modern browser - probably curl or something similar. don't care. return origin = origin.lower().strip() - if origin == 'none': - quart.abort(403, 'Origin is none') - origin_parts = origin.split('/') + if origin == "none": + quart.abort(403, "Origin is none") + origin_parts = origin.split("/") # Origin should look like: :// (optionally followed by :) if len(origin_parts) < 3: - quart.abort(400, 'Broken Origin header') - if origin_parts[0] != 'https:' and not app.my_config.debug: + quart.abort(400, "Broken Origin header") + if origin_parts[0] != "https:" and not app.my_config.debug: # -> require https in production - quart.abort(403, 'Non-https Origin not allowed') + quart.abort(403, "Non-https Origin not allowed") origin_host = origin_parts[2] - host = quart.request.headers.get('Host', None) + host = quart.request.headers.get("Host", None) if host is None: - quart.abort(403, 'Missing Host header') + quart.abort(403, "Missing Host header") if host.lower() != origin_host: - quart.abort(403, 'Origin mismatch') + quart.abort(403, "Origin mismatch") -@app.route('/', methods=['GET']) +@app.route("/", methods=["GET"]) async def index(missing_accept: bool = False): state = await user_lookup() if not state.mac: - return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept) + return await render_i18n_template("index_unknown.html", state=state, missing_accept=missing_accept) elif state.allowed: - return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept) + return await render_i18n_template("index_active.html", state=state, missing_accept=missing_accept) else: - return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept) + return await render_i18n_template("index_inactive.html", state=state, missing_accept=missing_accept) -@app.route('/login', methods=['POST']) +@app.route("/login", methods=["POST"]) async def login(): check_self_origin() with trio.fail_after(5.0): form = await quart.request.form - if form.get('accept') != '1': + if form.get("accept") != "1": return await index(missing_accept=True) - req_mac = form.get('mac') + req_mac = form.get("mac") if not req_mac: - quart.abort(400, description='Missing MAC in request form data') + quart.abort(400, description="Missing MAC in request form data") address = get_client_ip() mac = await get_client_mac(address) if str(mac) != req_mac: quart.abort(403, description="Passed MAC in request form doesn't match client address") await user_login(address, mac) - return quart.redirect('/', code=303) + return quart.redirect("/", code=303) -@app.route('/logout', methods=['POST']) +@app.route("/logout", methods=["POST"]) async def logout(): check_self_origin() with trio.fail_after(5.0): form = await quart.request.form - req_mac = form.get('mac') + req_mac = form.get("mac") if not req_mac: - quart.abort(400, description='Missing MAC in request form data') + quart.abort(400, description="Missing MAC in request form data") mac = await get_client_mac() if str(mac) != req_mac: quart.abort(403, description="Passed MAC in request form doesn't match client address") await user_logout(mac) - return quart.redirect('/', code=303) + return quart.redirect("/", code=303) -@app.route('/api/captive-portal', methods=['GET']) +@app.route("/api/captive-portal", methods=["GET"]) # RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908 async def captive_api(): state = await user_lookup() diff --git a/src/capport/comm/hub.py b/src/capport/comm/hub.py index 8cd3216..f1eaf74 100644 --- a/src/capport/comm/hub.py +++ b/src/capport/comm/hub.py @@ -43,7 +43,7 @@ class Channel: _logger.debug(f"{self}: created (server_side={server_side})") def __repr__(self) -> str: - return f'Channel[0x{id(self):x}]' + return f"Channel[0x{id(self):x}]" async def do_handshake(self) -> capport.comm.message.Hello: try: @@ -59,9 +59,8 @@ class Channel: peer_hello = (await self.recv_msg()).to_variant() if not isinstance(peer_hello, capport.comm.message.Hello): raise HubConnectionReadError("Expected Hello as first message") - auth_succ = \ - (peer_hello.authentication == - self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)) + expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside) + auth_succ = peer_hello.authentication == expected_auth await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) peer_auth = (await self.recv_msg()).to_variant() if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): @@ -72,7 +71,7 @@ class Channel: async def _read(self, num: int) -> bytes: assert num > 0 - buf = b'' + buf = b"" # _logger.debug(f"{self}:_read({num})") while num > 0: try: @@ -92,7 +91,7 @@ class Channel: async def _recv_raw_msg(self) -> bytes: len_bytes = await self._read(4) - chunk_size, = struct.unpack('!I', len_bytes) + (chunk_size,) = struct.unpack("!I", len_bytes) chunk = await self._read(chunk_size) if chunk is None: raise HubConnectionReadError("Unexpected end of TLS stream after chunk length") @@ -116,7 +115,7 @@ class Channel: async def send_msg(self, msg: capport.comm.message.Message): chunk = msg.SerializeToString(deterministic=True) chunk_size = len(chunk) - len_bytes = struct.pack('!I', chunk_size) + len_bytes = struct.pack("!I", chunk_size) chunk = len_bytes + chunk await self._send_raw(chunk) @@ -158,7 +157,7 @@ class Connection: if msg: await self._channel.send_msg(msg) else: - await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message()) + await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message()) except trio.TooSlowError: _logger.warning(f"{self._channel}: send timed out") except ConnectionError as e: @@ -303,7 +302,7 @@ class Hub: if is_controller: state_filename = config.database_file else: - state_filename = '' + state_filename = "" self.database = capport.database.Database(state_filename=state_filename) self._anon_context = ssl.SSLContext() # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon @@ -311,14 +310,14 @@ class Hub: self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2 # -> AECDH-AES256-SHA # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? - self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0') + self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0") self._controllers: dict[str, ControllerConn] = {} self._established: dict[uuid.UUID, Connection] = {} async def _accept(self, stream): remotename = stream.socket.getpeername() if isinstance(remotename, tuple) and len(remotename) == 2: - remote = f'[{remotename[0]}]:{remotename[1]}' + remote = f"[{remotename[0]}]:{remotename[1]}" else: remote = str(remotename) try: @@ -351,11 +350,11 @@ class Hub: await trio.sleep_forever() def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes: - m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256) + m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256) if server_side: - m.update(b'server$') + m.update(b"server$") else: - m.update(b'client$') + m.update(b"client$") m.update(ssl_binding) return m.digest() diff --git a/src/capport/comm/message.py b/src/capport/comm/message.py index c1eabd2..40ef23e 100644 --- a/src/capport/comm/message.py +++ b/src/capport/comm/message.py @@ -6,7 +6,7 @@ from .protobuf import message_pb2 def _message_to_variant(self: message_pb2.Message) -> typing.Any: - variant_name = self.WhichOneof('oneof') + variant_name = self.WhichOneof("oneof") if variant_name: return getattr(self, variant_name) return None @@ -16,14 +16,15 @@ def _make_to_message(oneof_field): def to_message(self) -> message_pb2.Message: msg = message_pb2.Message(**{oneof_field: self}) return msg + return to_message def _monkey_patch(): g = globals() - g['Message'] = message_pb2.Message + g["Message"] = message_pb2.Message message_pb2.Message.to_variant = _message_to_variant - for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields: + for field in message_pb2._MESSAGE.oneofs_by_name["oneof"].fields: type_name = field.message_type.name field_type = getattr(message_pb2, type_name) field_type.to_message = _make_to_message(field.name) diff --git a/src/capport/comm/message.pyi b/src/capport/comm/message.pyi index aa2c61f..ada2975 100644 --- a/src/capport/comm/message.pyi +++ b/src/capport/comm/message.pyi @@ -1,10 +1,8 @@ import google.protobuf.message import typing - # manually maintained typehints for protobuf created (and monkey-patched) types - class Message(google.protobuf.message.Message): hello: Hello authentication_result: AuthenticationResult @@ -14,15 +12,13 @@ class Message(google.protobuf.message.Message): def __init__( self, *, - hello: Hello | None=None, - authentication_result: AuthenticationResult | None=None, - ping: Ping | None=None, - mac_states: MacStates | None=None, + hello: Hello | None = None, + authentication_result: AuthenticationResult | None = None, + ping: Ping | None = None, + mac_states: MacStates | None = None, ) -> None: ... - def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ... - class Hello(google.protobuf.message.Message): instance_id: bytes hostname: str @@ -32,51 +28,43 @@ class Hello(google.protobuf.message.Message): def __init__( self, *, - instance_id: bytes=b'', - hostname: str='', - is_controller: bool=False, - authentication: bytes=b'', + instance_id: bytes = b"", + hostname: str = "", + is_controller: bool = False, + authentication: bytes = b"", ) -> None: ... - def to_message(self) -> Message: ... - class AuthenticationResult(google.protobuf.message.Message): success: bool def __init__( self, *, - success: bool=False, + success: bool = False, ) -> None: ... - def to_message(self) -> Message: ... - class Ping(google.protobuf.message.Message): payload: bytes def __init__( self, *, - payload: bytes=b'', + payload: bytes = b"", ) -> None: ... - def to_message(self) -> Message: ... - class MacStates(google.protobuf.message.Message): states: list[MacState] def __init__( self, *, - states: list[MacState]=[], + states: list[MacState] = [], ) -> None: ... - def to_message(self) -> Message: ... - class MacState(google.protobuf.message.Message): mac_address: bytes last_change: int # Seconds of UTC time since epoch @@ -86,8 +74,8 @@ class MacState(google.protobuf.message.Message): def __init__( self, *, - mac_address: bytes=b'', - last_change: int=0, - allow_until: int=0, - allowed: bool=False, + mac_address: bytes = b"", + last_change: int = 0, + allow_until: int = 0, + allowed: bool = False, ) -> None: ... diff --git a/src/capport/config.py b/src/capport/config.py index ceeb3d2..dc6be3c 100644 --- a/src/capport/config.py +++ b/src/capport/config.py @@ -39,7 +39,7 @@ class Config: @staticmethod def load(filename: str | None = None) -> Config: if filename is None: - for name in ('capport.yaml', '/etc/capport.yaml'): + for name in ("capport.yaml", "/etc/capport.yaml"): if os.path.exists(name): return Config.load(name) raise RuntimeError("Missing config file") @@ -47,19 +47,19 @@ class Config: data = yaml.safe_load(f) if not isinstance(data, dict): raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}") - controllers = list(map(str, data.pop('controllers'))) + controllers = list(map(str, data.pop("controllers"))) config = Config( _source_filename=filename, controllers=controllers, - server_names=data.pop('server-names', []), - comm_secret=str(data.pop('comm-secret')), - cookie_secret=str(data.pop('cookie-secret')), - venue_info_url=str(data.pop('venue-info-url')), - session_timeout=data.pop('session-timeout', 3600), - api_port=data.pop('api-port', 8000), - controller_port=data.pop('controller-port', 5000), - database_file=str(data.pop('database-file', 'capport.state')), - debug=data.pop('debug', False) + server_names=data.pop("server-names", []), + comm_secret=str(data.pop("comm-secret")), + cookie_secret=str(data.pop("cookie-secret")), + venue_info_url=str(data.pop("venue-info-url")), + session_timeout=data.pop("session-timeout", 3600), + api_port=data.pop("api-port", 8000), + controller_port=data.pop("controller-port", 5000), + database_file=str(data.pop("database-file", "capport.state")), + debug=data.pop("debug", False), ) if data: _logger.error(f"Unknown config elements: {list(data.keys())}") diff --git a/src/capport/control/run.py b/src/capport/control/run.py index 190a6c8..b7c8a38 100644 --- a/src/capport/control/run.py +++ b/src/capport/control/run.py @@ -58,9 +58,9 @@ async def amain(config: capport.config.Config) -> None: async with trio.open_nursery() as nursery: # hub.run loads the statefile from disk before signalling it was "started" await nursery.start(hub.run) - await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set') + await sn.send("READY=1", "STATUS=Deploying initial entries to nftables set") app.apply_db_entries(hub.database.entries()) - await sn.send('STATUS=Kernel fully synchronized') + await sn.send("STATUS=Kernel fully synchronized") @dataclasses.dataclass @@ -69,7 +69,7 @@ class CliArguments: def __init__(self): parser = argparse.ArgumentParser() - parser.add_argument('--config', '-c') + parser.add_argument("--config", "-c") args = parser.parse_args() self.config = args.config diff --git a/src/capport/cptypes.py b/src/capport/cptypes.py index f725db3..51ebe63 100644 --- a/src/capport/cptypes.py +++ b/src/capport/cptypes.py @@ -23,14 +23,14 @@ class MacAddress: raw: bytes def __str__(self) -> str: - return self.raw.hex(':') + return self.raw.hex(":") def __repr__(self) -> str: return repr(str(self)) @staticmethod def parse(s: str) -> MacAddress: - return MacAddress(bytes.fromhex(s.replace(':', ''))) + return MacAddress(bytes.fromhex(s.replace(":", ""))) @dataclasses.dataclass(frozen=True, order=True) @@ -40,9 +40,9 @@ class Timestamp: def __str__(self) -> str: try: ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc) - return ts.isoformat(sep=' ') + return ts.isoformat(sep=" ") except OSError: - return f'epoch@{self.epoch}' + return f"epoch@{self.epoch}" def __repr__(self) -> str: return repr(str(self)) @@ -85,7 +85,7 @@ class MacPublicState: def allowed_remaining_duration(self) -> str: mm, ss = divmod(self.allowed_remaining, 60) hh, mm = divmod(mm, 60) - return f'{hh}:{mm:02}:{ss:02}' + return f"{hh}:{mm:02}:{ss:02}" @property def allowed_until(self) -> datetime.datetime | None: @@ -95,18 +95,18 @@ class MacPublicState: def to_rfc8908(self, config: Config) -> quart.Response: response: dict[str, typing.Any] = { - 'user-portal-url': quart.url_for('index', _external=True), + "user-portal-url": quart.url_for("index", _external=True), } if config.venue_info_url: - response['venue-info-url'] = config.venue_info_url + response["venue-info-url"] = config.venue_info_url if self.captive: - response['captive'] = True + response["captive"] = True else: - response['captive'] = False - response['seconds-remaining'] = self.allowed_remaining - response['can-extend-session'] = True + response["captive"] = False + response["seconds-remaining"] = self.allowed_remaining + response["can-extend-session"] = True return quart.Response( json.dumps(response), - headers={'Cache-Control': 'private'}, - content_type='application/captive+json', + headers={"Cache-Control": "private"}, + content_type="application/captive+json", ) diff --git a/src/capport/database.py b/src/capport/database.py index ec4ebb0..b2300eb 100644 --- a/src/capport/database.py +++ b/src/capport/database.py @@ -128,7 +128,7 @@ def _serialize_mac_states_as_messages( def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes: chunk = states.SerializeToString(deterministic=True) chunk_size = len(chunk) - len_bytes = struct.pack('!I', chunk_size) + len_bytes = struct.pack("!I", chunk_size) return len_bytes + chunk @@ -204,39 +204,32 @@ class Database: return _serialize_mac_states_as_messages(self._macs) def as_json(self) -> dict: - return { - str(addr): entry.as_json() - for addr, entry in self._macs.items() - } + return {str(addr): entry.as_json() for addr, entry in self._macs.items()} async def _run_statefile(self) -> None: - rx: trio.MemoryReceiveChannel[ - capport.comm.message.MacStates | list[capport.comm.message.MacStates], - ] - tx: trio.MemorySendChannel[ - capport.comm.message.MacStates | list[capport.comm.message.MacStates], - ] + rx: trio.MemoryReceiveChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]] + tx: trio.MemorySendChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]] tx, rx = trio.open_memory_channel(64) self._send_changes = tx assert self._state_filename filename: str = self._state_filename - tmp_filename = f'{filename}.new-{os.getpid()}' + tmp_filename = f"{filename}.new-{os.getpid()}" async def resync(all_states: list[capport.comm.message.MacStates]): try: - async with await trio.open_file(tmp_filename, 'xb') as tf: + async with await trio.open_file(tmp_filename, "xb") as tf: for states in all_states: await tf.write(_states_to_chunk(states)) os.rename(tmp_filename, filename) finally: if os.path.exists(tmp_filename): - _logger.warning(f'Removing (failed) state export file {tmp_filename}') + _logger.warning(f"Removing (failed) state export file {tmp_filename}") os.unlink(tmp_filename) try: while True: - async with await trio.open_file(filename, 'ab', buffering=0) as sf: + async with await trio.open_file(filename, "ab", buffering=0) as sf: while True: update = await rx.receive() if isinstance(update, list): @@ -247,10 +240,10 @@ class Database: await resync(update) # now reopen normal statefile and continue appending updates except trio.Cancelled: - _logger.info('Final sync to disk') + _logger.info("Final sync to disk") with trio.CancelScope(shield=True): await resync(_serialize_mac_states(self._macs)) - _logger.info('Final sync to disk done') + _logger.info("Final sync to disk done") async def _load_statefile(self): if not os.path.exists(self._state_filename): @@ -259,7 +252,7 @@ class Database: # we're going to ignore changes from loading the file pu = PendingUpdates(self) pu._closed = False - async with await trio.open_file(self._state_filename, 'rb') as sf: + async with await trio.open_file(self._state_filename, "rb") as sf: while True: try: len_bytes = await sf.read(4) @@ -268,7 +261,7 @@ class Database: if len(len_bytes) < 4: _logger.error("Failed to read next chunk from statefile (unexpected EOF)") return - chunk_size, = struct.unpack('!I', len_bytes) + (chunk_size,) = struct.unpack("!I", len_bytes) chunk = await sf.read(chunk_size) except IOError as e: _logger.error(f"Failed to read next chunk from statefile: {e}") @@ -286,7 +279,7 @@ class Database: except Exception as e: errors += 1 if errors < 5: - _logger.error(f'Failed to handle state: {e}') + _logger.error(f"Failed to handle state: {e}") def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState: entry = self._macs.get(mac) diff --git a/src/capport/stats.py b/src/capport/stats.py index b5f79be..b7d396f 100644 --- a/src/capport/stats.py +++ b/src/capport/stats.py @@ -14,20 +14,17 @@ from . import cptypes def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None): # no labels in our names for now, always print help and type if help: - print(f'# HELP {name} {help}') - print(f'# TYPE {name} {mtype}') + print(f"# HELP {name} {help}") + print(f"# TYPE {name} {mtype}") if now: - print(f'{name} {value} {now}') + print(f"{name} {value} {now}") else: - print(f'{name} {value}') + print(f"{name} {value}") async def amain(client_ifname: str): ns = capport.utils.nft_set.NftSet() - captive_allowed_entries: set[cptypes.MacAddress] = { - entry['mac'] - for entry in ns.list() - } + captive_allowed_entries: set[cptypes.MacAddress] = {entry["mac"] for entry in ns.list()} seen_allowed_entries: set[cptypes.MacAddress] = set() total_ipv4 = 0 total_ipv6 = 0 @@ -47,46 +44,46 @@ async def amain(client_ifname: str): total_ipv6 += 1 unique_ipv6.add(mac) print_metric( - 'capport_allowed_macs', - 'gauge', + "capport_allowed_macs", + "gauge", len(captive_allowed_entries), - help='Number of allowed client mac addresses', + help="Number of allowed client mac addresses", ) print_metric( - 'capport_allowed_neigh_macs', - 'gauge', + "capport_allowed_neigh_macs", + "gauge", len(seen_allowed_entries), - help='Number of allowed client mac addresses seen in neighbor cache', + help="Number of allowed client mac addresses seen in neighbor cache", ) print_metric( - 'capport_unique', - 'gauge', + "capport_unique", + "gauge", len(unique_clients), - help='Number of clients (mac addresses) in client network seen in neighbor cache', + help="Number of clients (mac addresses) in client network seen in neighbor cache", ) print_metric( - 'capport_unique_ipv4', - 'gauge', + "capport_unique_ipv4", + "gauge", len(unique_ipv4), - help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache', + help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache", ) print_metric( - 'capport_unique_ipv6', - 'gauge', + "capport_unique_ipv6", + "gauge", len(unique_ipv6), - help='Number of IPv6 clients (unique per mac) in client network seen in neighbor cache', + help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache", ) print_metric( - 'capport_total_ipv4', - 'gauge', + "capport_total_ipv4", + "gauge", total_ipv4, - help='Number of IPv4 addresses seen in neighbor cache', + help="Number of IPv4 addresses seen in neighbor cache", ) print_metric( - 'capport_total_ipv6', - 'gauge', + "capport_total_ipv6", + "gauge", total_ipv6, - help='Number of IPv6 addresses seen in neighbor cache', + help="Number of IPv6 addresses seen in neighbor cache", ) diff --git a/src/capport/utils/cli.py b/src/capport/utils/cli.py index 4ed1fce..8a9ee07 100644 --- a/src/capport/utils/cli.py +++ b/src/capport/utils/cli.py @@ -10,8 +10,8 @@ def init_logger(config: capport.config.Config): if config.debug: loglevel = logging.DEBUG logging.basicConfig( - format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s', - datefmt='[%Y-%m-%d %H:%M:%S %z]', + format="%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s", + datefmt="[%Y-%m-%d %H:%M:%S %z]", level=loglevel, ) - logging.getLogger('hypercorn').propagate = False + logging.getLogger("hypercorn").propagate = False diff --git a/src/capport/utils/ipneigh.py b/src/capport/utils/ipneigh.py index 1edccc5..7bfc511 100644 --- a/src/capport/utils/ipneigh.py +++ b/src/capport/utils/ipneigh.py @@ -35,9 +35,9 @@ class NeighborController: route = await self.get_route(address) if route is None: return None - index = route.get_attr(route.name2nla('oif')) + index = route.get_attr(route.name2nla("oif")) try: - return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0] + return self.ip.neigh("get", dst=str(address), ifindex=index, state="none")[0] except pyroute2.netlink.exceptions.NetlinkError as e: if e.code == errno.ENOENT: return None @@ -53,7 +53,7 @@ class NeighborController: neigh = await self.get_neighbor(address, index=index, flags=flags) if neigh is None: return None - mac = neigh.get_attr(neigh.name2nla('lladdr')) + mac = neigh.get_attr(neigh.name2nla("lladdr")) if mac is None: return None return cptypes.MacAddress.parse(mac) @@ -63,7 +63,7 @@ class NeighborController: address: cptypes.IPAddress, ) -> pyroute2.iproute.linux.rtmsg | None: try: - return self.ip.route('get', dst=str(address))[0] + return self.ip.route("get", dst=str(address))[0] except pyroute2.netlink.exceptions.NetlinkError as e: if e.code == errno.ENOENT: return None @@ -74,16 +74,16 @@ class NeighborController: interface: str, ) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]: ifindex = socket.if_nametoindex(interface) - unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] + unicast_num = pyroute2.netlink.rtnl.rt_type["unicast"] # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) for family in (socket.AF_INET, socket.AF_INET6): - for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family): - if neigh['ndm_type'] != unicast_num: + for neigh in self.ip.neigh("dump", ifindex=ifindex, family=family): + if neigh["ndm_type"] != unicast_num: continue - mac = neigh.get_attr(neigh.name2nla('lladdr')) + mac = neigh.get_attr(neigh.name2nla("lladdr")) if not mac: continue - dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst'))) + dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla("dst"))) if dst.is_link_local: continue yield (cptypes.MacAddress.parse(mac), dst) diff --git a/src/capport/utils/nft_set.py b/src/capport/utils/nft_set.py index 86a6933..39d2627 100644 --- a/src/capport/utils/nft_set.py +++ b/src/capport/utils/nft_set.py @@ -30,26 +30,20 @@ class NftSet: timeout: int | float | None = None, ) -> _nftsocket.nft_set_elem_list_msg.set_elem: attrs: dict[str, typing.Any] = { - 'NFTA_SET_ELEM_KEY': dict( + "NFTA_SET_ELEM_KEY": dict( NFTA_DATA_VALUE=mac.raw, ), } if timeout: - attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout) + attrs["NFTA_SET_ELEM_TIMEOUT"] = int(1000 * timeout) return attrs def _bulk_insert( self, entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]], ) -> None: - ser_entries = [ - self._set_elem(mac) - for mac, _timeout in entries - ] - ser_entries_with_timeout = [ - self._set_elem(mac, timeout) - for mac, timeout in entries - ] + ser_entries = [self._set_elem(mac) for mac, _timeout in entries] + ser_entries_with_timeout = [self._set_elem(mac, timeout) for mac, timeout in entries] with self._socket.begin() as tx: # create doesn't affect existing elements, so: # make sure entries exists @@ -58,8 +52,8 @@ class NftSet: pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, ), ) @@ -68,8 +62,8 @@ class NftSet: _nftsocket.NFT_MSG_DELSETELEM, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, ), ) @@ -79,8 +73,8 @@ class NftSet: pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout, ), ) @@ -95,10 +89,7 @@ class NftSet: self.bulk_insert([(mac, timeout)]) def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None: - ser_entries = [ - self._set_elem(mac) - for mac in entries - ] + ser_entries = [self._set_elem(mac) for mac in entries] with self._socket.begin() as tx: # make sure entries exists tx.put( @@ -106,8 +97,8 @@ class NftSet: pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, ), ) @@ -116,8 +107,8 @@ class NftSet: _nftsocket.NFT_MSG_DELSETELEM, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, ), ) @@ -137,20 +128,20 @@ class NftSet: _nftsocket.NFT_MSG_GETSETELEM, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', - ) + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", + ), ) return [ { - 'mac': cptypes.MacAddress( - elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'), + "mac": cptypes.MacAddress( + elem.get_attr("NFTA_SET_ELEM_KEY").get_attr("NFTA_DATA_VALUE"), ), - 'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)), - 'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)), + "timeout": _from_msec(elem.get_attr("NFTA_SET_ELEM_TIMEOUT", None)), + "expiration": _from_msec(elem.get_attr("NFTA_SET_ELEM_EXPIRATION", None)), } for response in responses - for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', []) + for elem in response.get_attr("NFTA_SET_ELEM_LIST_ELEMENTS", []) ] def flush(self) -> None: @@ -158,9 +149,9 @@ class NftSet: _nftsocket.NFT_MSG_DELSETELEM, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_ELEM_LIST_TABLE='captive_mark', - NFTA_SET_ELEM_LIST_SET='allowed', - ) + NFTA_SET_ELEM_LIST_TABLE="captive_mark", + NFTA_SET_ELEM_LIST_SET="allowed", + ), ) def create(self): @@ -170,7 +161,7 @@ class NftSet: pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_TABLE_NAME='captive_mark', + NFTA_TABLE_NAME="captive_mark", ), ) tx.put( @@ -178,8 +169,8 @@ class NftSet: pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( - NFTA_SET_TABLE='captive_mark', - NFTA_SET_NAME='allowed', + NFTA_SET_TABLE="captive_mark", + NFTA_SET_NAME="allowed", NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft NFTA_SET_KEY_LEN=6, # length of key: mac address diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py index a45d788..3d1ea31 100644 --- a/src/capport/utils/nft_socket.py +++ b/src/capport/utils/nft_socket.py @@ -13,7 +13,7 @@ from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" -_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base) +_NlMsgBase = typing.TypeVar("_NlMsgBase", bound=pyroute2.netlink.nlmsg_base) # nft uses NESTED for those.. lets do the same @@ -25,16 +25,17 @@ _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM def _monkey_patch_pyroute2(): import pyroute2.netlink + # overwrite setdefault on nlmsg_base class hierarchy _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue def _nlmsg_base__setvalue(self, value): - if not self.header or not self['header'] or not isinstance(value, dict): + if not self.header or not self["header"] or not isinstance(value, dict): return _orig_setvalue(self, value) # merge headers instead of overwriting them - header = value.pop('header', {}) + header = value.pop("header", {}) res = _orig_setvalue(self, value) - self['header'].update(header) + self["header"].update(header) return res def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None: @@ -52,20 +53,20 @@ _monkey_patch_pyroute2() def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: msg = msg_class() for key, value in header.items(): - msg['header'][key] = value + msg["header"][key] = value for key, value in fields.items(): msg[key] = value if attrs: - attr_list = msg['attrs'] + attr_list = msg["attrs"] r_nla_map = msg_class._nlmsg_base__r_nla_map for key, value in attrs.items(): if msg_class.prefix: key = msg_class.name2nla(key) prime = r_nla_map[key] - nla_class = prime['class'] + nla_class = prime["class"] if issubclass(nla_class, pyroute2.netlink.nla): # support passing nested attributes as dicts of subattributes (or lists of those) - if prime['nla_array']: + if prime["nla_array"]: value = [ _build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) @@ -83,10 +84,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): def __init__(self) -> None: super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) - policy = { - (x | (NFNL_SUBSYS_NFTABLES << 8)): y - for (x, y) in self.policy.items() - } + policy = {(x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items()} self.register_policy(policy) @contextlib.contextmanager @@ -156,7 +154,7 @@ class NFTTransaction: # no inner messages were queued... not sending anything return # request ACK on the last message (before END) - self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK + self._msgs[-1]["header"]["flags"] |= pyroute2.netlink.NLM_F_ACK self._msgs.append( # batch end message _build( diff --git a/src/capport/utils/sd_notify.py b/src/capport/utils/sd_notify.py index ee61b43..ec0ef2a 100644 --- a/src/capport/utils/sd_notify.py +++ b/src/capport/utils/sd_notify.py @@ -10,7 +10,7 @@ import trio.socket def _check_watchdog_pid() -> bool: - wpid = os.environ.pop('WATCHDOG_PID', None) + wpid = os.environ.pop("WATCHDOG_PID", None) if not wpid: return True return wpid == str(os.getpid()) @@ -18,16 +18,16 @@ def _check_watchdog_pid() -> bool: @contextlib.asynccontextmanager async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]: - target = os.environ.pop('NOTIFY_SOCKET', None) + target = os.environ.pop("NOTIFY_SOCKET", None) ns: trio.socket.SocketType | None = None watchdog_usec: int = 0 if target: - if target.startswith('@'): + if target.startswith("@"): # Linux abstract namespace socket - target = '\0' + target[1:] + target = "\0" + target[1:] ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) await ns.connect(target) - watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None) + watchdog_usec_s = os.environ.pop("WATCHDOG_USEC", None) if _check_watchdog_pid() and watchdog_usec_s: watchdog_usec = int(watchdog_usec_s) try: @@ -52,18 +52,18 @@ class SdNotify: async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None: assert self.is_connected(), "Watchdog can't run without socket" - await self.send('WATCHDOG=1') + await self.send("WATCHDOG=1") task_status.started() # send every half of the watchdog timeout - interval = (watchdog_usec/1e6) / 2.0 + interval = (watchdog_usec / 1e6) / 2.0 while True: await trio.sleep(interval) - await self.send('WATCHDOG=1') + await self.send("WATCHDOG=1") async def send(self, *msg: str) -> None: if not self.is_connected(): return - dgram = '\n'.join(msg).encode('utf-8') + dgram = "\n".join(msg).encode("utf-8") assert self._ns, "not connected" # checked above sent = await self._ns.send(dgram) if sent != len(dgram): diff --git a/src/capport/utils/zoneinfo.py b/src/capport/utils/zoneinfo.py index e9567b8..5d1ea92 100644 --- a/src/capport/utils/zoneinfo.py +++ b/src/capport/utils/zoneinfo.py @@ -10,9 +10,9 @@ def get_local_timezone(): global _zoneinfo if not _zoneinfo: try: - with open('/etc/timezone') as f: + with open("/etc/timezone") as f: key = f.readline().strip() _zoneinfo = zoneinfo.ZoneInfo(key) except (OSError, zoneinfo.ZoneInfoNotFoundError): - _zoneinfo = zoneinfo.ZoneInfo('UTC') + _zoneinfo = zoneinfo.ZoneInfo("UTC") return _zoneinfo