from __future__ import annotations

import dataclasses
import typing

import capport.comm.message
from capport import cptypes


@dataclasses.dataclass
class MacEntry:
    """State tracked for a single MAC address.

    Entries are merged with states received from other systems (see
    `merge`) and can be dropped once `outdated` reports True.
    """

    # entry can be removed if last_change was some time ago and allow_until wasn't set
    # or got reached.
    WAIT_LAST_CHANGE_SECONDS = 60
    WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10

    # last_change: timestamp of last change (sent by system initiating the change)
    last_change: cptypes.Timestamp
    # only if allowed is true and allow_until is set the device can communicate with the internet
    # allow_until must not go backwards (and not get unset)
    allow_until: typing.Optional[cptypes.Timestamp]
    allowed: bool

    @staticmethod
    def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
        """Build a `(address, entry)` pair from a received `MacState` message.

        Raises:
            ValueError: if the MAC address has fewer than 6 bytes or the
                message carries no `last_change` timestamp.
        """
        if len(msg.mac_address) < 6:
            raise ValueError("Invalid MacState: mac_address too short")
        addr = cptypes.MacAddress(raw=msg.mac_address)
        last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
        if not last_change:
            raise ValueError(f"Invalid MacState[{addr}]: missing last_change")
        # allow_until is optional; whatever from_protobuf returns is stored as is
        allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
        return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))

    def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
        """Convert this entry into a `MacState` message for `addr`.

        An unset `allow_until` is encoded as 0.
        """
        allow_until = self.allow_until.epoch if self.allow_until else 0
        return capport.comm.message.MacState(
            mac_address=addr.raw,
            last_change=self.last_change.epoch,
            allow_until=allow_until,
            allowed=self.allowed,
        )

    def as_json(self) -> dict:
        """Return a dict with the entry's fields; an unset `allow_until` becomes None."""
        allow_until = self.allow_until.epoch if self.allow_until else None
        return dict(
            last_change=self.last_change.epoch,
            allow_until=allow_until,
            allowed=self.allowed,
        )

    def merge(self, new: MacEntry) -> bool:
        """Merge `new` into this entry in place.

        Returns:
            True if any field of this entry was modified, False otherwise.
        """
        changed = False
        if new.last_change > self.last_change:
            # newer change wins: take over its timestamp and allowed flag
            changed = True
            self.last_change = new.last_change
            self.allowed = new.allowed
        elif new.last_change == self.last_change:
            # same last_change: set allowed if one allowed
            if new.allowed and not self.allowed:
                changed = True
                self.allowed = True
        # set allow_until to max of both
        if new.allow_until:  # if not set nothing to change in local data
            if not self.allow_until or self.allow_until < new.allow_until:
                changed = True
                self.allow_until = new.allow_until
        return changed

    def timeout(self) -> cptypes.Timestamp:
        """Return the point in time after which this entry counts as outdated.

        This is the later of `last_change + WAIT_LAST_CHANGE_SECONDS` and,
        if `allow_until` is set, `allow_until + WAIT_ALLOW_UNTIL_PASSED_SECONDS`.
        """
        elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
        if self.allow_until:
            eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
            if eau > elc:
                return cptypes.Timestamp(epoch=eau)
        return cptypes.Timestamp(epoch=elc)

    # returns 0 if not allowed
    def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
        """Return the remaining allowed time (difference of epochs), never negative.

        Args:
            now: reference time; defaults to `cptypes.Timestamp.now()`.
        """
        if not self.allowed or not self.allow_until:
            return 0
        if not now:
            now = cptypes.Timestamp.now()
        return max(self.allow_until.epoch - now.epoch, 0)

    def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
        """Return True if `now` lies strictly after `timeout()`.

        Args:
            now: reference time; defaults to `cptypes.Timestamp.now()`.
        """
        if not now:
            now = cptypes.Timestamp.now()
        return now.epoch > self.timeout().epoch


# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(
    macs: typing.Dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.MacStates]:
    """Convert all entries into `MacStates` containers of at most 1024 states each.

    Returns an empty list if `macs` is empty.
    """
    result: typing.List[capport.comm.message.MacStates] = []
    current = capport.comm.message.MacStates()
    for addr, entry in macs.items():
        current.states.append(entry.to_state(addr))
        if len(current.states) >= 1024:  # split into messages with 1024 states
            result.append(current)
            current = capport.comm.message.MacStates()
    # keep the trailing, partially filled container (if any)
    if len(current.states):
        result.append(current)
    return result


def _serialize_mac_states_as_messages(
    macs: typing.Dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
    """Like `_serialize_mac_states`, but with each container turned into a message via `to_message()`."""
    return [s.to_message() for s in _serialize_mac_states(macs)]


class NotReadyYet(Exception):
    """Raised when a login/logout cannot take effect yet; `wait` says how long to wait."""

    def __init__(self, msg: str, wait: int):
        self.wait = wait  # seconds to wait
        super().__init__(msg)


@dataclasses.dataclass
class Database:
    """In-memory mapping from MAC address to `MacEntry`.

    Methods that modify entries record the affected entries in the
    `pending_updates` object passed by the caller.
    """

    _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
        """Apply a received `MacState` and record the entry in `pending_updates` if it changed anything.

        Raises:
            ValueError: if `state` is malformed (see `MacEntry.parse_state`).
        """
        (addr, new_entry) = MacEntry.parse_state(state)
        old_entry = self._macs.get(addr)
        if old_entry is None:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._macs[addr] = new_entry
                pending_updates.macs[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._macs.pop(addr)
            pending_updates.macs[addr] = old_entry

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Return all stored entries as a list of messages."""
        return _serialize_mac_states_as_messages(self._macs)

    def as_json(self) -> dict:
        """Return a dict mapping `str(address)` to the entry's `as_json()` data."""
        return {
            str(addr): entry.as_json()
            for addr, entry in self._macs.items()
        }

    def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
        """Return the public state for `mac`; unknown MACs get `allowed_remaining=0`."""
        entry = self._macs.get(mac)
        allowed_remaining = entry.allowed_remaining() if entry is not None else 0
        return cptypes.MacPublicState(
            address=address,
            mac=mac,
            allowed_remaining=allowed_remaining,
        )

    def login(
        self,
        mac: cptypes.MacAddress,
        session_timeout: int,
        *,
        pending_updates: PendingUpdates,
        renew_maximum: float = 0.8,
    ):
        """Allow `mac` for `session_timeout` seconds starting now.

        Args:
            mac: address to log in.
            session_timeout: session length, added to the current epoch.
            pending_updates: receives the entry if it was created or modified.
            renew_maximum: an existing session is left untouched while its
                remaining time exceeds `renew_maximum * session_timeout`.

        Raises:
            NotReadyYet: if the existing entry could not be updated and the
                device is not allowed.
        """
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)

        entry = self._macs.get(mac)
        if entry is None:
            self._macs[mac] = new_entry
            pending_updates.macs[mac] = new_entry
        elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        elif entry.merge(new_entry):
            # NOTE(review): merge() also returns True when it only extended allow_until
            # while entry.last_change lies in the future; `allowed` may then still be
            # False and no NotReadyYet is raised - confirm this is intended.
            pending_updates.macs[mac] = entry
        elif not entry.allowed_remaining(now) > 0:
            # entry should have been updated - can only fail due to `now < entry.last_change`
            # i.e. out of sync clocks
            # (evaluated against the same `now` that is used to compute `wait`)
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
        """Revoke access for `mac`; does nothing if the MAC is unknown.

        Args:
            mac: address to log out.
            pending_updates: receives the entry if it was modified.

        Raises:
            NotReadyYet: if the entry could not be updated and the device is
                still allowed.
        """
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._macs.get(mac)
        if entry is None:
            return
        if entry.merge(new_entry):
            pending_updates.macs[mac] = entry
        elif entry.allowed_remaining(now) > 0:
            # still logged in. can only happen with `now <= entry.last_change`
            # clocks not necessarily out of sync, but you can't logout in the same second you logged in
            # (evaluated against the same `now` that is used to compute `wait`)
            wait = entry.last_change.epoch - now.epoch + 1
            raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)


@dataclasses.dataclass
class PendingUpdates:
    """Collects entries that were changed, keyed by MAC address."""

    macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Return the collected entries as a list of messages."""
        return _serialize_mac_states_as_messages(self.macs)