from __future__ import annotations

import contextlib
import dataclasses
import logging
import os
import struct
import typing

import google.protobuf.message
import trio

import capport.comm.message
from capport import cptypes

_logger = logging.getLogger(__name__)

# Maximum number of MacState entries packed into a single MacStates message.
_MAX_STATES_PER_MESSAGE = 1024

# An update queued for the state file writer (see Database._run_statefile):
# a single MacStates message is appended to the state file, a list of them
# replaces the whole state file (resync).
_StateUpdate = typing.Union[
    capport.comm.message.MacStates,
    typing.List[capport.comm.message.MacStates],
]


@dataclasses.dataclass
class MacEntry:
    """Access state of a single MAC address."""

    # entry can be removed if last_change was some time ago and allow_until wasn't set
    # or got reached.
    WAIT_LAST_CHANGE_SECONDS = 60
    WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10

    # last_change: timestamp of last change (sent by system initiating the change)
    last_change: cptypes.Timestamp
    # only if allowed is true and allow_until is set the device can communicate with the internet
    # allow_until must not go backwards (and not get unset)
    allow_until: typing.Optional[cptypes.Timestamp]
    allowed: bool

    @staticmethod
    def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
        """Decode a protobuf ``MacState`` into ``(address, entry)``.

        Raises:
            Exception: if ``mac_address`` is shorter than 6 bytes, or if
                ``Timestamp.from_protobuf`` yields a falsy value for
                ``last_change`` (treated as "missing").
        """
        if len(msg.mac_address) < 6:
            raise Exception("Invalid MacState: mac_address too short")
        addr = cptypes.MacAddress(raw=msg.mac_address)
        last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
        if not last_change:
            raise Exception(f"Invalid MacState[{addr}]: missing last_change")
        allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
        return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))

    def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
        """Encode this entry for ``addr`` as a protobuf ``MacState``.

        An unset ``allow_until`` is encoded as 0.
        """
        allow_until = 0
        if self.allow_until:
            allow_until = self.allow_until.epoch
        return capport.comm.message.MacState(
            mac_address=addr.raw,
            last_change=self.last_change.epoch,
            allow_until=allow_until,
            allowed=self.allowed,
        )

    def as_json(self) -> dict:
        """Return a JSON-compatible dict; timestamps as epoch values, unset ``allow_until`` as None."""
        allow_until = None
        if self.allow_until:
            allow_until = self.allow_until.epoch
        return dict(
            last_change=self.last_change.epoch,
            allow_until=allow_until,
            allowed=self.allowed,
        )

    def merge(self, new: MacEntry) -> bool:
        """Merge ``new`` into this entry in place; return True if anything changed.

        - A newer ``last_change`` wins and takes over ``allowed``.
        - With an equal ``last_change``, ``allowed`` can only be switched on.
        - ``allow_until`` always becomes the later of both and is never unset.
        """
        changed = False
        if new.last_change > self.last_change:
            changed = True
            self.last_change = new.last_change
            self.allowed = new.allowed
        elif new.last_change == self.last_change:
            # same last_change: set allowed if one allowed
            if new.allowed and not self.allowed:
                changed = True
                self.allowed = True
        # set allow_until to max of both
        if new.allow_until:  # if not set nothing to change in local data
            if not self.allow_until or self.allow_until < new.allow_until:
                changed = True
                self.allow_until = new.allow_until
        return changed

    def timeout(self) -> cptypes.Timestamp:
        """Return the time after which this entry may be dropped.

        This is ``WAIT_LAST_CHANGE_SECONDS`` after ``last_change``, or
        ``WAIT_ALLOW_UNTIL_PASSED_SECONDS`` after ``allow_until`` if that is later.
        """
        elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
        if self.allow_until:
            eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
            if eau > elc:
                return cptypes.Timestamp(epoch=eau)
        return cptypes.Timestamp(epoch=elc)

    # returns 0 if not allowed
    def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
        """Return the seconds of access left at ``now`` (default: current time), never negative."""
        if not self.allowed or not self.allow_until:
            return 0
        if not now:
            now = cptypes.Timestamp.now()
        assert self.allow_until  # narrowing for type checkers
        return max(self.allow_until.epoch - now.epoch, 0)

    def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
        """Return True once ``now`` (default: current time) is strictly past ``timeout()``."""
        if not now:
            now = cptypes.Timestamp.now()
        return now.epoch > self.timeout().epoch


# used both for the state file and (via _serialize_mac_states_as_messages) for peers
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
    """Pack ``macs`` into MacStates messages of at most ``_MAX_STATES_PER_MESSAGE`` states each.

    Returns an empty list when ``macs`` is empty.
    """
    result: typing.List[capport.comm.message.MacStates] = []
    current = capport.comm.message.MacStates()
    for addr, entry in macs.items():
        current.states.append(entry.to_state(addr))
        if len(current.states) >= _MAX_STATES_PER_MESSAGE:
            # split into several messages to bound message size
            result.append(current)
            current = capport.comm.message.MacStates()
    if len(current.states):
        result.append(current)
    return result


def _serialize_mac_states_as_messages(
    macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
    """Like ``_serialize_mac_states``, wrapping each MacStates via its ``to_message()``."""
    return [s.to_message() for s in _serialize_mac_states(macs)]


def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
    """Encode ``states`` as one state file chunk: 4-byte big-endian length + serialized message."""
    chunk = states.SerializeToString(deterministic=True)
    len_bytes = struct.pack('!I', len(chunk))
    return len_bytes + chunk


class NotReadyYet(Exception):
    """Raised when an operation has to be retried later; ``wait`` is the suggested delay in seconds."""

    def __init__(self, msg: str, wait: int):
        self.wait = wait  # seconds to wait
        super().__init__(msg)


class Database:
    """In-memory table of ``MacEntry`` objects keyed by MAC address.

    If ``state_filename`` is given, ``run()`` loads it on startup and keeps it
    updated. The file is a sequence of chunks, each a 4-byte big-endian length
    followed by a serialized ``MacStates`` message.
    """

    def __init__(self, state_filename: typing.Optional[str] = None):
        self._macs: dict[cptypes.MacAddress, MacEntry] = {}
        self._state_filename = state_filename
        # set when entries were changed or dropped; run() then triggers a full resync
        self._changed_since_last_cleanup = False
        # send side of the state file writer's channel; set by _run_statefile
        self._send_changes: typing.Optional[trio.MemorySendChannel[_StateUpdate]] = None

    @contextlib.asynccontextmanager
    async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
        """Yield an open ``PendingUpdates`` to modify the database through.

        After the ``async with`` body completes, the batch is closed. If it
        recorded changes, the serialized states are queued for the state file
        writer (when it is running).

        NOTE(review): if the body raises, ``_finish`` is not reached. Changes
        already applied to the database are then not forwarded.
        """
        pu = PendingUpdates(self)
        pu._closed = False
        yield pu
        pu._finish()
        if pu:
            self._changed_since_last_cleanup = True
            if self._send_changes:
                for state in pu.serialized_states:
                    await self._send_changes.send(state)

    def _drop_outdated(self) -> None:
        """Remove all entries whose ``timeout()`` has passed.

        Entries are collected in batches of up to 1024. After deleting a
        batch, the dict is rescanned, because it cannot be modified while
        being iterated.
        """
        done = False
        while not done:
            depr: typing.Set[cptypes.MacAddress] = set()
            now = cptypes.Timestamp.now()
            done = True
            for mac, entry in self._macs.items():
                if entry.outdated(now):
                    depr.add(mac)
                    if len(depr) >= 1024:
                        # clear entries found so far, then try again
                        done = False
                        break
            if depr:
                self._changed_since_last_cleanup = True
                for mac in depr:
                    del self._macs[mac]

    async def run(self, task_status=trio.TASK_STATUS_IGNORED):
        """Load the state file, then run state file persistence and periodic cleanup forever.

        ``task_status.started()`` is signalled once the state file is loaded.
        """
        if self._state_filename:
            await self._load_statefile()
        task_status.started()
        async with trio.open_nursery() as nursery:
            if self._state_filename:
                nursery.start_soon(self._run_statefile)
            while True:
                await trio.sleep(300)  # cleanup every 5 minutes
                _logger.debug("Running database cleanup")
                self._drop_outdated()
                if self._changed_since_last_cleanup:
                    self._changed_since_last_cleanup = False
                    if self._send_changes:
                        states = _serialize_mac_states(self._macs)
                        # trigger a resync
                        await self._send_changes.send(states)

    # for initial handling of all data
    def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
        """Return a snapshot list of all ``(address, entry)`` pairs."""
        return list(self._macs.items())

    # for initial sync with new peer
    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Serialize the whole database into messages."""
        return _serialize_mac_states_as_messages(self._macs)

    def as_json(self) -> dict:
        """Return a JSON-compatible dict mapping ``str(address)`` to ``MacEntry.as_json()``."""
        return {
            str(addr): entry.as_json()
            for addr, entry in self._macs.items()
        }

    async def _run_statefile(self) -> None:
        """Persist database changes to the state file until cancelled.

        Single ``MacStates`` updates from the channel are appended to the file.
        A list of ``MacStates`` rewrites the whole file: it is written to a
        temporary file, which is then renamed over the state file. On
        cancellation, a final full sync of the database is written before
        ``trio.Cancelled`` propagates.
        """
        rx: trio.MemoryReceiveChannel[_StateUpdate]
        tx: trio.MemorySendChannel[_StateUpdate]
        tx, rx = trio.open_memory_channel(64)
        self._send_changes = tx

        assert self._state_filename
        filename: str = self._state_filename
        tmp_filename = f'{filename}.new-{os.getpid()}'

        async def resync(all_states: typing.List[capport.comm.message.MacStates]):
            # write the full snapshot into a new file (closed before renaming),
            # and remove the temporary file if anything went wrong
            try:
                async with await trio.open_file(tmp_filename, 'xb') as tf:
                    for states in all_states:
                        await tf.write(_states_to_chunk(states))
                os.rename(tmp_filename, filename)
            finally:
                if os.path.exists(tmp_filename):
                    _logger.warning(f'Removing (failed) state export file {tmp_filename}')
                    os.unlink(tmp_filename)

        try:
            while True:
                async with await trio.open_file(filename, 'ab', buffering=0) as sf:
                    while True:
                        update = await rx.receive()
                        if isinstance(update, list):
                            break
                        await sf.write(_states_to_chunk(update))
                # got a "list" update - i.e. a resync
                with trio.CancelScope(shield=True):
                    await resync(update)
                # now reopen normal statefile and continue appending updates
        except trio.Cancelled:
            _logger.info('Final sync to disk')
            with trio.CancelScope(shield=True):
                await resync(_serialize_mac_states(self._macs))
            _logger.info('Final sync to disk done')
            # trio requires Cancelled to propagate; swallowing it breaks cancellation
            raise

    async def _load_statefile(self):
        """Load entries from the state file (if it exists) into the database.

        States go through ``PendingUpdates.received_mac_state``, so merge rules
        apply and outdated entries are skipped; the recorded changes are
        discarded. A read error or truncated length prefix stops loading; an
        undecodable chunk is skipped. Only the first 4 per-state errors are
        logged individually.
        """
        if not os.path.exists(self._state_filename):
            return
        _logger.info("Loading statefile")
        # we're going to ignore changes from loading the file
        pu = PendingUpdates(self)
        pu._closed = False
        max_logged_errors = 4
        errors = 0  # counted across the whole file so logging is really rate limited
        async with await trio.open_file(self._state_filename, 'rb') as sf:
            while True:
                try:
                    len_bytes = await sf.read(4)
                    if not len_bytes:
                        break
                    if len(len_bytes) < 4:
                        _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
                        break
                    chunk_size, = struct.unpack('!I', len_bytes)
                    chunk = await sf.read(chunk_size)
                except IOError as e:
                    _logger.error(f"Failed to read next chunk from statefile: {e}")
                    break
                if len(chunk) < chunk_size:
                    # still try to decode it; the next read hits EOF and stops the loop
                    _logger.warning(
                        f"Truncated chunk in statefile (expected {chunk_size} bytes, got {len(chunk)})",
                    )
                try:
                    states = capport.comm.message.MacStates()
                    states.ParseFromString(chunk)
                except google.protobuf.message.DecodeError as e:
                    _logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
                    continue
                for state in states.states:
                    try:
                        pu.received_mac_state(state)
                    except Exception as e:
                        errors += 1
                        if errors <= max_logged_errors:
                            _logger.error(f'Failed to handle state: {e}')
        if errors > max_logged_errors:
            _logger.error(
                f'Suppressed {errors - max_logged_errors} further errors while handling states from statefile',
            )

    def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
        """Return the public state for ``mac`` (seen at ``address``); unknown MACs get 0 seconds remaining."""
        entry = self._macs.get(mac)
        if entry:
            allowed_remaining = entry.allowed_remaining()
        else:
            allowed_remaining = 0
        return cptypes.MacPublicState(
            address=address,
            mac=mac,
            allowed_remaining=allowed_remaining,
        )


class PendingUpdates:
    """A batch of changes made through ``Database.make_changes()``.

    Mutating methods apply changes directly to the database's entries and
    record the affected entries. After ``_finish()`` closes the batch, the
    serialized forms of the recorded changes are available.
    """

    def __init__(self, database: Database):
        self._changes: dict[cptypes.MacAddress, MacEntry] = {}
        self._database = database
        # starts closed; the creator opens it by setting _closed = False
        self._closed = True
        self._serialized_states: typing.List[capport.comm.message.MacStates] = []
        self._serialized: typing.List[capport.comm.message.Message] = []

    def __bool__(self) -> bool:
        """True if any change was recorded."""
        return bool(self._changes)

    def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
        """Return the recorded ``(address, entry)`` changes."""
        return self._changes.items()

    @property
    def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
        """Recorded changes as MacStates messages; only valid after the batch is closed."""
        assert self._closed
        return self._serialized_states

    @property
    def serialized(self) -> typing.List[capport.comm.message.Message]:
        """Recorded changes as messages; only valid after the batch is closed."""
        assert self._closed
        return self._serialized

    def _ensure_open(self) -> None:
        """Raise if the batch is closed (already finished, or never opened)."""
        if self._closed:
            raise Exception("Can't change closed PendingUpdates")

    def _finish(self):
        """Close the batch and serialize the recorded changes."""
        self._ensure_open()
        self._closed = True
        self._serialized_states = _serialize_mac_states(self._changes)
        self._serialized = [s.to_message() for s in self._serialized_states]

    def received_mac_state(self, state: capport.comm.message.MacState):
        """Merge a ``MacState`` received from elsewhere into the database.

        Raises whatever ``MacEntry.parse_state`` raises for invalid states.
        """
        self._ensure_open()
        (addr, new_entry) = MacEntry.parse_state(state)
        old_entry = self._database._macs.get(addr)
        if not old_entry:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._database._macs[addr] = new_entry
                self._changes[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._database._macs.pop(addr)
            self._changes[addr] = old_entry

    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
        """Allow ``mac`` for ``session_timeout`` seconds from now.

        An existing session with more than ``renew_maximum * session_timeout``
        seconds left is not renewed.

        Raises:
            NotReadyYet: if the stored entry could not be updated and is not
                allowed, i.e. its ``last_change`` lies in the future
                (out of sync clocks).
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
        entry = self._database._macs.get(mac)
        if not entry:
            self._database._macs[mac] = new_entry
            self._changes[mac] = new_entry
        elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        elif entry.merge(new_entry):
            self._changes[mac] = entry
        elif not entry.allowed_remaining(now) > 0:
            # entry should have been updated - can only fail due to `now < entry.last_change`
            # i.e. out of sync clocks
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress):
        """Revoke access for ``mac``; unknown MACs are ignored.

        Raises:
            NotReadyYet: if the logout cannot win against the stored entry,
                which is still allowed (``now <= entry.last_change``).
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._database._macs.get(mac)
        if entry:
            if entry.merge(new_entry):
                self._changes[mac] = entry
            elif entry.allowed_remaining(now) > 0:
                # still logged in. can only happen with `now <= entry.last_change`
                # clocks not necessarily out of sync, but you can't logout in the same second you logged in
                wait = entry.last_change.epoch - now.epoch + 1
                raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)