diff --git a/.gitignore b/.gitignore index 8be8a4b..96d6443 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ __pycache__ venv capport.yaml custom +capport.state +capport.state.new-* diff --git a/src/capport/api/setup.py b/src/capport/api/setup.py index 7437e1a..f8c23d5 100644 --- a/src/capport/api/setup.py +++ b/src/capport/api/setup.py @@ -28,7 +28,7 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: app.my_nc = mync _logger.info("Running hub for API") myapp = ApiHubApp() - myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp) + myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False) app.my_hub = myhub await myhub.run(task_status=task_status) finally: diff --git a/src/capport/api/views.py b/src/capport/api/views.py index cafad84..e7cc3ea 100644 --- a/src/capport/api/views.py +++ b/src/capport/api/views.py @@ -56,28 +56,28 @@ async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cp async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None: assert app.my_hub # for mypy - pu = capport.database.PendingUpdates() - try: - app.my_hub.database.login(mac, app.my_config.session_timeout, pending_updates=pu) - except capport.database.NotReadyYet as e: - quart.abort(500, str(e)) + async with app.my_hub.database.make_changes() as pu: + try: + pu.login(mac, app.my_config.session_timeout) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) - if pu.macs: + if pu: _logger.debug(f'User {mac} (with IP {address}) logged in') - for msg in pu.serialize(): + for msg in pu.serialized: await app.my_hub.broadcast(msg) async def user_logout(mac: cptypes.MacAddress) -> None: assert app.my_hub # for mypy - pu = capport.database.PendingUpdates() - try: - app.my_hub.database.logout(mac, pending_updates=pu) - except capport.database.NotReadyYet as e: - quart.abort(500, str(e)) - if pu.macs: + async with app.my_hub.database.make_changes() as pu: + try: + pu.logout(mac) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) + if pu: _logger.debug(f'User {mac} logged out') - for msg in pu.serialize(): + for msg in pu.serialized: await app.my_hub.broadcast(msg) diff --git a/src/capport/comm/hub.py b/src/capport/comm/hub.py index 6f02aa0..4805ef0 100644 --- a/src/capport/comm/hub.py +++ b/src/capport/comm/hub.py @@ -265,9 +265,6 @@ class ControllerConn: class HubApplication: - def is_controller(self) -> bool: - return False - async def new_peer(self, *, peer_id: uuid.UUID) -> None: _logger.info(f"New peer {peer_id}") @@ -287,13 +284,18 @@ class HubApplication: class Hub: - def __init__(self, config: Config, app: HubApplication) -> None: + def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None: self._config = config self._instance_id = uuid.uuid4() self._hostname = socket.getfqdn() - self.database = capport.database.Database() self._app = app - self._is_controller = bool(app.is_controller()) + self._is_controller = is_controller + state_filename: typing.Optional[str] + if is_controller: + state_filename = 'capport.state' + else: + state_filename = None + self.database = capport.database.Database(state_filename=state_filename) self._anon_context = ssl.SSLContext() # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2 @@ -324,6 +326,7 @@ class Hub: async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): async with trio.open_nursery() as nursery: + await nursery.start(self.database.run) if self._is_controller: await nursery.start(self._listen) @@ -412,12 +415,12 @@ class Hub: pass elif isinstance(variant, capport.comm.message.MacStates): await self._app.received_mac_state(from_peer_id=peer_id, states=variant) - pu = capport.database.PendingUpdates() - for state in variant.states: - self.database.received_mac_state(state, pending_updates=pu) - if pu.macs: + async with self.database.make_changes() as pu: + for state in variant.states: + pu.received_mac_state(state) + if pu: # re-broadcast all received updates to all peers - await self.broadcast(*pu.serialize(), exclude=peer_id) + await self.broadcast(*pu.serialized, exclude=peer_id) await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu) else: await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg) diff --git a/src/capport/control/run.py b/src/capport/control/run.py index 4ce75f2..e32bcfb 100644 --- a/src/capport/control/run.py +++ b/src/capport/control/run.py @@ -1,12 +1,13 @@ from __future__ import annotations import logging +import typing import uuid -import capport.database import capport.comm.hub import capport.comm.message import capport.config +import capport.database import capport.utils.cli import capport.utils.nft_set import trio @@ -22,15 +23,15 @@ class ControlApp(capport.comm.hub.HubApplication): super().__init__() self.nft_set = capport.utils.nft_set.NftSet() - def is_controller(self) -> bool: - return True - async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + self.apply_db_entries(pending_updates.changes()) + + def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None: # deploy changes to netfilter set inserts = [] removals = [] now = cptypes.Timestamp.now() - for mac, state in pending_updates.macs.items(): + for mac, state in entries: rem = state.allowed_remaining(now) if rem > 0: inserts.append((mac, rem)) @@ -42,9 +43,12 @@ class ControlApp(capport.comm.hub.HubApplication): async def amain(config: capport.config.Config) -> None: app = ControlApp() - hub = capport.comm.hub.Hub(config=config, app=app) + hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True) app.hub = hub - await hub.run() + async with trio.open_nursery() as nursery: + # hub.run loads the statefile from disk before signalling it was "started" + await nursery.start(hub.run) + app.apply_db_entries(hub.database.entries()) def main() -> None: diff --git a/src/capport/database.py b/src/capport/database.py index c634c49..197b588 100644 --- a/src/capport/database.py +++ b/src/capport/database.py @@ -1,11 +1,20 @@ from __future__ import annotations +import contextlib import dataclasses +import logging +import os +import struct import typing +import google.protobuf.message +import trio + import capport.comm.message from capport import cptypes +_logger = logging.getLogger(__name__) + @dataclasses.dataclass class MacEntry: @@ -113,30 +122,82 @@ def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacE return [s.to_message() for s in _serialize_mac_states(macs)] +def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes: + chunk = states.SerializeToString(deterministic=True) + chunk_size = len(chunk) + len_bytes = struct.pack('!I', chunk_size) + return len_bytes + chunk + + class NotReadyYet(Exception): def __init__(self, msg: str, wait: int): self.wait = wait # seconds to wait super().__init__(msg) -@dataclasses.dataclass class Database: - _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict) + def __init__(self, state_filename: typing.Optional[str]=None): + self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {} + self._state_filename = state_filename + self._changed_since_last_cleanup = False + self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[ + capport.comm.message.MacStates, + typing.List[capport.comm.message.MacStates], + ]]] = None - def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates): - (addr, new_entry) = MacEntry.parse_state(state) - old_entry = self._macs.get(addr) - if not old_entry: - # only redistribute if not outdated - if not new_entry.outdated(): - self._macs[addr] = new_entry - pending_updates.macs[addr] = new_entry - elif old_entry.merge(new_entry): - if old_entry.outdated(): - # remove local entry, but still redistribute - self._macs.pop(addr) - pending_updates.macs[addr] = old_entry + @contextlib.asynccontextmanager + async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]: + pu = PendingUpdates(self) + pu._closed = False + yield pu + pu._finish() + if pu: + self._changed_since_last_cleanup = True + if self._send_changes: + for state in pu.serialized_states: + await self._send_changes.send(state) + def _drop_outdated(self) -> None: + done = False + while not done: + depr: typing.Set[cptypes.MacAddress] = set() + now = cptypes.Timestamp.now() + done = True + for mac, entry in self._macs.items(): + if entry.outdated(now): + depr.add(mac) + if len(depr) >= 1024: + # clear entries found so far, then try again + done = False + break + if depr: + self._changed_since_last_cleanup = True + for mac in depr: + del self._macs[mac] + + async def run(self, task_status=trio.TASK_STATUS_IGNORED): + if self._state_filename: + await self._load_statefile() + task_status.started() + async with trio.open_nursery() as nursery: + if self._state_filename: + nursery.start_soon(self._run_statefile) + while True: + await trio.sleep(10) # sleep 15 minutes + _logger.debug("Running database cleanup") + self._drop_outdated() + if self._changed_since_last_cleanup: + self._changed_since_last_cleanup = False + if self._send_changes: + states = _serialize_mac_states(self._macs) + # trigger a resync + await self._send_changes.send(states) + + # for initial handling of all data + def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]: + return list(self._macs.items()) + + # for initial sync with new peer def serialize(self) -> typing.List[capport.comm.message.Message]: return _serialize_mac_states_as_messages(self._macs) @@ -146,6 +207,87 @@ class Database: for addr, entry in self._macs.items() } + async def _run_statefile(self) -> None: + rx: trio.MemoryReceiveChannel[typing.Union[ + capport.comm.message.MacStates, + typing.List[capport.comm.message.MacStates], + ]] + tx: trio.MemorySendChannel[typing.Union[ + capport.comm.message.MacStates, + typing.List[capport.comm.message.MacStates], + ]] + tx, rx = trio.open_memory_channel(64) + self._send_changes = tx + + assert self._state_filename + filename: str = self._state_filename + tmp_filename = f'{filename}.new-{os.getpid()}' + + async def resync(all_states: typing.List[capport.comm.message.MacStates]): + try: + async with await trio.open_file(tmp_filename, 'xb') as tf: + for states in all_states: + await tf.write(_states_to_chunk(states)) + os.rename(tmp_filename, filename) + finally: + if os.path.exists(tmp_filename): + _logger.warning(f'Removing (failed) state export file {tmp_filename}') + os.unlink(tmp_filename) + + try: + while True: + async with await trio.open_file(filename, 'ab') as sf: + while True: + update = await rx.receive() + if isinstance(update, list): + break + await sf.write(_states_to_chunk(update)) + # got a "list" update - i.e. a resync + with trio.CancelScope(shield=True): + await resync(update) + # now reopen normal statefile and continue appending updates + except trio.Cancelled: + _logger.info('Final sync to disk') + with trio.CancelScope(shield=True): + await resync(_serialize_mac_states(self._macs)) + _logger.info('Final sync to disk done') + + async def _load_statefile(self): + if not os.path.exists(self._state_filename): + return + _logger.info("Loading statefile") + # we're going to ignore changes from loading the file + pu = PendingUpdates(self) + pu._closed = False + async with await trio.open_file(self._state_filename, 'rb') as sf: + while True: + try: + len_bytes = await sf.read(4) + if not len_bytes: + return + if len(len_bytes) < 4: + _logger.error("Failed to read next chunk from statefile (unexpected EOF)") + return + chunk_size, = struct.unpack('!I', len_bytes) + chunk = await sf.read(chunk_size) + except IOError as e: + _logger.error(f"Failed to read next chunk from statefile: {e}") + return + try: + states = capport.comm.message.MacStates() + states.ParseFromString(chunk) + except google.protobuf.message.DecodeError as e: + _logger.error(f"Failed to decode chunk from statefile, trying next one: {e}") + continue + for state in states: + errors = 0 + try: + pu.received_mac_state(state) + except Exception as e: + errors += 1 + if errors < 5: + _logger.error(f'Failed to handle state: {e}') + def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState: entry = self._macs.get(mac) if entry: @@ -158,43 +300,87 @@ class Database: allowed_remaining=allowed_remaining, ) - def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8): + +class PendingUpdates: + def __init__(self, database: Database): + self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {} + self._database = database + self._closed = True + self._serialized_states: typing.List[capport.comm.message.MacStates] = [] + self._serialized: typing.List[capport.comm.message.Message] = [] + + def __bool__(self) -> bool: + return bool(self._changes) + + def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]: + return self._changes.items() + + @property + def serialized_states(self) -> typing.List[capport.comm.message.MacStates]: + assert self._closed + return self._serialized_states + + @property + def serialized(self) -> typing.List[capport.comm.message.Message]: + assert self._closed + return self._serialized + + def _finish(self): + if self._closed: + raise Exception("Can't change closed PendingUpdates") + self._closed = True + self._serialized_states = _serialize_mac_states(self._changes) + self._serialized = [s.to_message() for s in self._serialized_states] + + def received_mac_state(self, state: capport.comm.message.MacState): + if self._closed: + raise Exception("Can't change closed PendingUpdates") + (addr, new_entry) = MacEntry.parse_state(state) + old_entry = self._database._macs.get(addr) + if not old_entry: + # only redistribute if not outdated + if not new_entry.outdated(): + self._database._macs[addr] = new_entry + self._changes[addr] = new_entry + elif old_entry.merge(new_entry): + if old_entry.outdated(): + # remove local entry, but still redistribute + self._database._macs.pop(addr) + self._changes[addr] = old_entry + + def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8): + if self._closed: + raise Exception("Can't change closed PendingUpdates") now = cptypes.Timestamp.now() allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout) new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True) - entry = self._macs.get(mac) + entry = self._database._macs.get(mac) if not entry: - self._macs[mac] = new_entry - pending_updates.macs[mac] = new_entry + self._database._macs[mac] = new_entry + self._changes[mac] = new_entry elif entry.allowed_remaining(now) > renew_maximum * session_timeout: # too much time left on clock, not renewing session return elif entry.merge(new_entry): - pending_updates.macs[mac] = entry + self._changes[mac] = entry elif not entry.allowed_remaining() > 0: # entry should have been updated - can only fail due to `now < entry.last_change` # i.e. out of sync clocks wait = entry.last_change.epoch - now.epoch raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait) - def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates): + def logout(self, mac: cptypes.MacAddress): + if self._closed: + raise Exception("Can't change closed PendingUpdates") now = cptypes.Timestamp.now() new_entry = MacEntry(last_change=now, allow_until=None, allowed=False) - entry = self._macs.get(mac) + entry = self._database._macs.get(mac) if entry: if entry.merge(new_entry): - pending_updates.macs[mac] = entry + self._changes[mac] = entry elif entry.allowed_remaining() > 0: # still logged in. can only happen with `now <= entry.last_change` # clocks not necessarily out of sync, but you can't logout in the same second you logged in wait = entry.last_change.epoch - now.epoch + 1 raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait) - - -@dataclasses.dataclass -class PendingUpdates: - macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict) - - def serialize(self) -> typing.List[capport.comm.message.Message]: - return _serialize_mac_states_as_messages(self.macs)