3
0

86 lines
2.4 KiB
Python

from __future__ import annotations
import argparse
import dataclasses
import typing
import uuid
import trio
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.database
import capport.utils.cli
import capport.utils.nft_set
from capport import cptypes
from capport.utils.sd_notify import open_sdnotify
class ControlApp(capport.comm.hub.HubApplication):
    """Hub application that mirrors MAC allowance states into an nftables set.

    A MAC whose database entry still has allowed time left is inserted into
    the set together with that remaining time; every other MAC is removed.
    """

    # Declared only; __init__ does not assign it, so the creator of the app
    # must set it after constructing the Hub.
    hub: capport.comm.hub.Hub

    def __init__(self) -> None:
        super().__init__()
        # Handle through which inserts/removals are pushed to the nftables set.
        self.nft_set = capport.utils.nft_set.NftSet()

    async def mac_states_changed(
        self,
        *,
        from_peer_id: uuid.UUID,
        pending_updates: capport.database.PendingUpdates,
    ) -> None:
        """Apply a batch of MAC state updates to the nftables set.

        Presumably invoked by the hub when states change. ``from_peer_id`` is
        not used here; only the changed entries of ``pending_updates`` matter.
        """
        changed_entries = pending_updates.changes()
        self.apply_db_entries(changed_entries)

    def apply_db_entries(
        self,
        entries: typing.Iterable[tuple[cptypes.MacAddress, capport.database.MacEntry]],
    ) -> None:
        """Synchronize ``(mac, entry)`` pairs into the nftables set.

        The remaining allowed time of every entry is evaluated against one
        shared timestamp. MACs with a positive remainder are inserted with
        that remainder; all others are removed. One bulk insert is issued,
        followed by one bulk removal. ``entries`` is iterated exactly once,
        so a generator is fine.
        """
        # The unit of the remainder is whatever MacEntry.allowed_remaining
        # returns (presumably seconds -- TODO confirm).
        now = cptypes.Timestamp.now()
        remaining_by_mac = [(mac, entry.allowed_remaining(now)) for mac, entry in entries]
        still_allowed = [(mac, rem) for mac, rem in remaining_by_mac if rem > 0]
        not_allowed = [mac for mac, rem in remaining_by_mac if not rem > 0]
        self.nft_set.bulk_insert(still_allowed)
        self.nft_set.bulk_remove(not_allowed)
async def amain(config: capport.config.Config) -> None:
    """Run the controller hub and deploy the full database to nftables.

    Readiness (``READY=1``) is sent via sd_notify once the hub has started,
    i.e. after it has loaded its state file. Only then are all known database
    entries pushed into the nftables set, with progress reported as
    ``STATUS=`` messages. The call returns only when the nursery finishes,
    which means when the hub task ends.
    """
    async with open_sdnotify() as notifier:
        control = ControlApp()
        hub = capport.comm.hub.Hub(
            config=config,
            app=control,
            is_controller=True,
        )
        # ControlApp declares but does not assign its back-reference.
        control.hub = hub
        async with trio.open_nursery() as tasks:
            # hub.run loads the statefile from disk before signalling it was "started"
            await tasks.start(hub.run)
            await notifier.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
            control.apply_db_entries(hub.database.entries())
            await notifier.send('STATUS=Kernel fully synchronized')
@dataclasses.dataclass
class CliArguments:
config: str | None
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c')
args = parser.parse_args()
self.config = args.config
def main() -> None:
    """Entry point: parse options, load the config, set up logging, and run ``amain``.

    A keyboard interrupt (or ``InterruptedError``) ends the program quietly:
    only a newline is printed, presumably so that the shell prompt does not
    continue on the line of the echoed ``^C``.
    """
    cli = CliArguments()
    cfg = capport.config.Config.load_default_once(filename=cli.config)
    capport.utils.cli.init_logger(cfg)
    try:
        trio.run(amain, cfg)
    except (InterruptedError, KeyboardInterrupt):
        print()