201 lines
8.0 KiB
Python
201 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import typing
|
|
|
|
import capport.comm.message
|
|
from capport import cptypes
|
|
|
|
|
|
@dataclasses.dataclass
class MacEntry:
    """Replicated access state of a single MAC address.

    Entries are decoded from / encoded to ``MacState`` messages
    (``parse_state`` / ``to_state``) and reconciled with ``merge``.
    """

    # entry can be removed if last_change was some time ago and allow_until
    # wasn't set or got reached (see ``timeout``).
    WAIT_LAST_CHANGE_SECONDS = 60
    WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10

    # last_change: timestamp of last change (sent by system initiating the change)
    last_change: cptypes.Timestamp
    # only if allowed is true and allow_until is set the device can communicate
    # with the internet; allow_until must not go backwards (and not get unset)
    allow_until: typing.Optional[cptypes.Timestamp]
    allowed: bool

    @staticmethod
    def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
        """Decode a ``MacState`` message into ``(address, entry)``.

        Raises:
            ValueError: if ``mac_address`` is shorter than 6 bytes or
                ``last_change`` is missing. (``ValueError`` subclasses
                ``Exception``, so existing ``except Exception`` handlers still
                catch it.)
        """
        if len(msg.mac_address) < 6:
            raise ValueError("Invalid MacState: mac_address too short")
        addr = cptypes.MacAddress(raw=msg.mac_address)
        last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
        # a falsy result means the field was not set in the message
        if not last_change:
            raise ValueError(f"Invalid MacState[{addr}]: missing last_change")
        allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
        return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))

    def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
        """Encode this entry for ``addr`` as a ``MacState`` (unset allow_until is sent as 0)."""
        allow_until = 0 if self.allow_until is None else self.allow_until.epoch
        return capport.comm.message.MacState(
            mac_address=addr.raw,
            last_change=self.last_change.epoch,
            allow_until=allow_until,
            allowed=self.allowed,
        )

    def as_json(self) -> dict:
        """Return a JSON-friendly dict of epoch values (``allow_until`` is None when unset)."""
        return dict(
            last_change=self.last_change.epoch,
            allow_until=None if self.allow_until is None else self.allow_until.epoch,
            allowed=self.allowed,
        )

    def merge(self, new: MacEntry) -> bool:
        """Merge ``new`` into this entry in place and return whether anything changed.

        - A newer ``last_change`` wins and its ``allowed`` flag is taken over.
        - On equal ``last_change`` the entry becomes allowed if either side is.
        - ``allow_until`` becomes the later of both values and is never unset.
        """
        changed = False
        if new.last_change > self.last_change:
            changed = True
            self.last_change = new.last_change
            self.allowed = new.allowed
        elif new.last_change == self.last_change:
            # same last_change: set allowed if one allowed
            if new.allowed and not self.allowed:
                changed = True
                self.allowed = True
        # set allow_until to max of both; if new has none, keep the local value
        if new.allow_until is not None:
            if self.allow_until is None or self.allow_until < new.allow_until:
                changed = True
                self.allow_until = new.allow_until
        return changed

    def timeout(self) -> cptypes.Timestamp:
        """Return the time after which this entry may be dropped.

        That is the later of ``last_change + WAIT_LAST_CHANGE_SECONDS`` and,
        if set, ``allow_until + WAIT_ALLOW_UNTIL_PASSED_SECONDS``.
        """
        expires = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
        if self.allow_until is not None:
            expires = max(expires, self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS)
        return cptypes.Timestamp(epoch=expires)

    def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
        """Return the seconds of access left at ``now`` (default: current time).

        Returns 0 if the entry is not allowed, has no ``allow_until``, or
        ``allow_until`` has already passed.
        """
        allow_until = self.allow_until
        if not self.allowed or allow_until is None:
            return 0
        if now is None:
            now = cptypes.Timestamp.now()
        return max(allow_until.epoch - now.epoch, 0)

    def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
        """Return True if ``now`` (default: current time) is past ``timeout()``."""
        if now is None:
            now = cptypes.Timestamp.now()
        return now.epoch > self.timeout().epoch
|
|
|
|
|
|
def _serialize_mac_states(
    macs: typing.Dict[cptypes.MacAddress, MacEntry],
    *,
    chunk_size: int = 1024,
) -> typing.List[capport.comm.message.MacStates]:
    """Serialize ``macs`` into ``MacStates`` messages of at most ``chunk_size`` states each.

    Kept separate from the ``Message``-wrapping variant so it can be reused
    where the ``Message`` wrapper isn't needed (e.g. serializing into a file).

    Returns an empty list when ``macs`` is empty.

    Raises:
        ValueError: if ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    result: typing.List[capport.comm.message.MacStates] = []
    current = capport.comm.message.MacStates()
    for addr, entry in macs.items():
        current.states.append(entry.to_state(addr))
        if len(current.states) >= chunk_size:
            # chunk full: emit it and start a new one
            result.append(current)
            current = capport.comm.message.MacStates()
    # emit the trailing partial chunk, but never an empty one
    if len(current.states):
        result.append(current)
    return result
|
|
|
|
|
|
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
    """Serialize ``macs`` in chunks like ``_serialize_mac_states`` and wrap each chunk via ``to_message()``."""
    chunks = _serialize_mac_states(macs)
    return [chunk.to_message() for chunk in chunks]
|
|
|
|
|
|
class NotReadyYet(Exception):
    """Raised when an operation cannot be performed yet and should be retried.

    Attributes:
        wait: suggested number of seconds to wait before retrying.
    """

    def __init__(self, msg: str, wait: int):
        super().__init__(msg)
        # seconds to wait
        self.wait = wait
|
|
|
|
|
|
@dataclasses.dataclass
class Database:
    """In-memory table of ``MacEntry`` values keyed by MAC address.

    Mutating operations record every entry that has to be redistributed in
    the ``pending_updates`` object passed to them.
    """

    _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
        """Merge a ``MacState`` received from elsewhere into the local table.

        Unknown addresses are stored and queued only if the entry is not
        outdated. For known addresses the state is merged. If that changed
        anything, the merged entry is queued, and it is dropped locally when
        outdated.

        Raises whatever ``MacEntry.parse_state`` raises for malformed states.
        """
        addr, new_entry = MacEntry.parse_state(state)
        old_entry = self._macs.get(addr)
        if old_entry is None:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._macs[addr] = new_entry
                pending_updates.macs[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._macs.pop(addr)
            pending_updates.macs[addr] = old_entry

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Return all local entries as a list of messages."""
        return _serialize_mac_states_as_messages(self._macs)

    def as_json(self) -> dict:
        """Return ``{str(mac): entry.as_json()}`` for all local entries."""
        return {
            str(addr): entry.as_json()
            for addr, entry in self._macs.items()
        }

    def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
        """Return the public state of ``mac`` (0 seconds remaining if the MAC is unknown)."""
        entry = self._macs.get(mac)
        allowed_remaining = 0 if entry is None else entry.allowed_remaining()
        return cptypes.MacPublicState(
            address=address,
            mac=mac,
            allowed_remaining=allowed_remaining,
        )

    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float = 0.8):
        """Allow ``mac`` for ``session_timeout`` seconds starting now.

        If the existing session still has more than
        ``renew_maximum * session_timeout`` seconds left, nothing is changed.

        Raises:
            NotReadyYet: if the entry is still not allowed after merging. That
                can only happen when ``now < entry.last_change``, i.e. clocks
                are out of sync between nodes.
        """
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)

        entry = self._macs.get(mac)
        if entry is None:
            self._macs[mac] = new_entry
            pending_updates.macs[mac] = new_entry
            return
        if entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        if entry.merge(new_entry):
            pending_updates.macs[mac] = entry
        # Checked regardless of the merge result: merge can report a change
        # (allow_until extended) without making the entry allowed.
        if not entry.allowed:
            # entry should have been updated - can only fail due to `now < entry.last_change`
            # i.e. out of sync clocks
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
        """Revoke access for ``mac`` as of now; unknown MACs are ignored.

        Raises:
            NotReadyYet: if the entry is still allowed because the logout could
                not be applied (``now <= entry.last_change``).
        """
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._macs.get(mac)
        if entry is None:
            return
        if entry.merge(new_entry):
            pending_updates.macs[mac] = entry
        elif entry.allowed_remaining(now) > 0:
            # still logged in. can only happen with `now <= entry.last_change`
            # clocks not necessarily out of sync, but you can't logout in the same second you logged in
            wait = entry.last_change.epoch - now.epoch + 1
            raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
|
|
|
|
|
|
@dataclasses.dataclass
class PendingUpdates:
    """Collects entries that still need to be redistributed."""

    # MAC address -> entry to send; a later update for the same address
    # overwrites the earlier one
    macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Return the collected entries as a list of messages."""
        pending = self.macs
        return _serialize_mac_states_as_messages(pending)
|