diff --git a/src/capport/api/setup.py b/src/capport/api/setup.py index f8c23d5..b9707b1 100644 --- a/src/capport/api/setup.py +++ b/src/capport/api/setup.py @@ -10,6 +10,8 @@ import capport.utils.cli import capport.utils.ipneigh import trio +from capport.utils.sd_notify import open_sdnotify + from .app import app @@ -38,9 +40,19 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: await app.shutdown() +async def _setup(*, task_status=trio.TASK_STATUS_IGNORED): + async with open_sdnotify() as sn: + await sn.send('STATUS=Starting hub') + async with trio.open_nursery() as nursery: + await nursery.start(_run_hub) + await sn.send('READY=1', 'STATUS=Ready for client requests') + task_status.started() + # continue running hub and systemd watchdog handler + + @app.before_serving async def init(): app.debug = app.my_config.debug app.secret_key = app.my_config.cookie_secret capport.utils.cli.init_logger(app.my_config) - await app.nursery.start(_run_hub) + await app.nursery.start(_setup) diff --git a/src/capport/control/run.py b/src/capport/control/run.py index e32bcfb..ffc0871 100644 --- a/src/capport/control/run.py +++ b/src/capport/control/run.py @@ -12,6 +12,8 @@ import capport.utils.cli import capport.utils.nft_set import trio from capport import cptypes +from capport.utils.sd_notify import open_sdnotify + _logger = logging.getLogger(__name__) @@ -42,13 +44,16 @@ class ControlApp(capport.comm.hub.HubApplication): async def amain(config: capport.config.Config) -> None: - app = ControlApp() - hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True) - app.hub = hub - async with trio.open_nursery() as nursery: - # hub.run loads the statefile from disk before signalling it was "started" - await nursery.start(hub.run) - app.apply_db_entries(hub.database.entries()) + async with open_sdnotify() as sn: + app = ControlApp() + hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True) + app.hub = hub + async with trio.open_nursery() as nursery: + # hub.run loads the statefile from disk before signalling it was "started" + await nursery.start(hub.run) + await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set') + app.apply_db_entries(hub.database.entries()) + await sn.send('STATUS=Kernel fully synchronized') def main() -> None: diff --git a/src/capport/utils/sd_notify.py b/src/capport/utils/sd_notify.py new file mode 100644 index 0000000..66c52d9 --- /dev/null +++ b/src/capport/utils/sd_notify.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import typing +import trio +import trio.socket +import contextlib +import os +import socket + + +def _check_watchdog_pid() -> bool: + wpid = os.environ.pop('WATCHDOG_PID', None) + if not wpid: + return True + return wpid == str(os.getpid()) + + +@contextlib.asynccontextmanager +async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]: + target = os.environ.pop('NOTIFY_SOCKET', None) + ns: typing.Optional[trio.socket.SocketType] = None + watchdog_usec: int = 0 + if target: + if target.startswith('@'): + # Linux abstract namespace socket + target = '\0' + target[1:] + ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + await ns.connect(target) + watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None) + if _check_watchdog_pid() and watchdog_usec_s: + watchdog_usec = int(watchdog_usec_s) + try: + async with trio.open_nursery() as nursery: + sn = SdNotify(_ns=ns) + if watchdog_usec: + await nursery.start(sn._run_watchdog, watchdog_usec) + yield sn + # stop watchdoch + nursery.cancel_scope.cancel() + finally: + if ns: + ns.close() + + +class SdNotify: + def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None: + self._ns = _ns + + def is_connected(self) -> bool: + return not (self._ns is None) + + async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None: + assert self.is_connected(), "Watchdog can't run without socket" + await self.send('WATCHDOG=1') + task_status.started() + # send every half of the watchdog timeout + interval = (watchdog_usec/1e6) / 2.0 + while True: + await trio.sleep(interval) + await self.send('WATCHDOG=1') + + async def send(self, *msg: str) -> None: + if not self.is_connected(): + return + dgram = '\n'.join(msg).encode('utf-8') + sent = await self._ns.send(dgram) + if sent != len(dgram): + raise OSError("Sent incomplete datagram to NOTIFY_SOCKET") +