# python-capport/src/capport/database.py
# (web listing metadata: 201 lines, 8.0 KiB, Python; dated 2022-04-04 19:21:51 +02:00)
from __future__ import annotations
import dataclasses
import typing
import capport.comm.message
from capport import cptypes
@dataclasses.dataclass
class MacEntry:
    """Access state of a single MAC address.

    Fields:
        last_change: timestamp of the last change (sent by the system that
            initiated the change).
        allow_until: end of the allowed period; must not go backwards (and
            must not get unset once set).
        allowed: whether the device is logged in.

    Only when ``allowed`` is true AND ``allow_until`` is set can the device
    communicate with the internet.
    """

    # An entry can be removed once last_change is this many seconds in the past
    # and allow_until either wasn't set or has been reached (see timeout()).
    WAIT_LAST_CHANGE_SECONDS = 60
    # Extra seconds after allow_until has passed before the entry can be removed.
    WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10

    last_change: cptypes.Timestamp
    allow_until: typing.Optional[cptypes.Timestamp]
    allowed: bool

    @staticmethod
    def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
        """Decode a MacState message into ``(address, entry)``.

        Raises:
            Exception: if the MAC address has fewer than 6 bytes or
                ``last_change`` is missing.
        """
        raw_mac = msg.mac_address
        if len(raw_mac) < 6:
            raise Exception("Invalid MacState: mac_address too short")
        mac = cptypes.MacAddress(raw=raw_mac)
        changed_at = cptypes.Timestamp.from_protobuf(msg.last_change)
        if not changed_at:
            raise Exception(f"Invalid MacState[{mac}]: missing last_change")
        entry = MacEntry(
            last_change=changed_at,
            allow_until=cptypes.Timestamp.from_protobuf(msg.allow_until),
            allowed=msg.allowed,
        )
        return mac, entry

    def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
        """Encode this entry for ``addr`` as a MacState (an unset allow_until is sent as 0)."""
        until_epoch = self.allow_until.epoch if self.allow_until else 0
        return capport.comm.message.MacState(
            mac_address=addr.raw,
            last_change=self.last_change.epoch,
            allow_until=until_epoch,
            allowed=self.allowed,
        )

    def as_json(self) -> dict:
        """Return a JSON-serializable dict; ``allow_until`` is None when unset."""
        return {
            "last_change": self.last_change.epoch,
            "allow_until": self.allow_until.epoch if self.allow_until else None,
            "allowed": self.allowed,
        }

    def merge(self, new: MacEntry) -> bool:
        """Merge ``new`` into this entry in place; return True if anything changed.

        - a strictly newer ``last_change`` wins and overwrites ``allowed``;
        - with an equal ``last_change`` the entry becomes allowed if either side is;
        - ``allow_until`` becomes the later of both values, independent of which
          ``last_change`` is newer (it never moves backwards).
        """
        modified = False
        if new.last_change > self.last_change:
            self.last_change = new.last_change
            self.allowed = new.allowed
            modified = True
        elif new.last_change == self.last_change and new.allowed and not self.allowed:
            self.allowed = True
            modified = True
        # an unset allow_until in `new` never touches local data
        if new.allow_until and (not self.allow_until or self.allow_until < new.allow_until):
            self.allow_until = new.allow_until
            modified = True
        return modified

    def timeout(self) -> cptypes.Timestamp:
        """Return the time after which this entry may be removed.

        This is the later of ``last_change + WAIT_LAST_CHANGE_SECONDS`` and, if
        set, ``allow_until + WAIT_ALLOW_UNTIL_PASSED_SECONDS``.
        """
        deadline = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
        if self.allow_until:
            deadline = max(deadline, self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS)
        return cptypes.Timestamp(epoch=deadline)

    def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int:
        """Return the seconds left in the allowed period, never negative.

        Returns 0 when the entry is not allowed or ``allow_until`` is unset.
        ``now`` defaults to the current time.
        """
        if self.allowed and self.allow_until:
            if not now:
                now = cptypes.Timestamp.now()
            return max(self.allow_until.epoch - now.epoch, 0)
        return 0

    def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool:
        """Return True once ``now`` (default: current time) is past ``timeout()``."""
        reference = now if now else cptypes.Timestamp.now()
        return reference.epoch > self.timeout().epoch
# Kept separate from the Message variant below: might be used to serialize
# into a file, where the Message wrapper isn't needed.
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
    """Pack ``macs`` into MacStates containers holding at most 1024 states each.

    States appear in the dict's iteration order. An empty ``macs`` yields an
    empty list (no empty container is produced).
    """
    batches: typing.List[capport.comm.message.MacStates] = []
    batch = capport.comm.message.MacStates()
    for mac, mac_entry in macs.items():
        batch.states.append(mac_entry.to_state(mac))
        if len(batch.states) < 1024:
            continue
        # batch is full: close it and start a new one
        batches.append(batch)
        batch = capport.comm.message.MacStates()
    if len(batch.states) > 0:
        batches.append(batch)
    return batches
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
    """Serialize ``macs`` into messages, one per MacStates chunk."""
    messages: typing.List[capport.comm.message.Message] = []
    for chunk in _serialize_mac_states(macs):
        messages.append(chunk.to_message())
    return messages
class NotReadyYet(Exception):
    """An operation can't be performed yet and should be retried later.

    Attributes:
        wait: suggested number of seconds to wait before retrying.
    """

    def __init__(self, msg: str, wait: int):
        super().__init__(msg)
        self.wait = wait
@dataclasses.dataclass
class Database:
    """In-memory table of MAC address entries.

    Local actions (login/logout) and states received from elsewhere
    (received_mac_state) are combined with MacEntry.merge. Every entry that
    changed is recorded in the caller's PendingUpdates so it can be
    redistributed.
    """

    # current entry per MAC address
    _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
        """Merge a received MacState into the table.

        Without a local entry, the received one is stored and redistributed
        unless it is already outdated. Otherwise it is merged into the local
        entry; on change the entry is redistributed and, if the merged entry is
        outdated, dropped locally.

        Raises:
            Exception: propagated from MacEntry.parse_state for an invalid state.
        """
        (addr, new_entry) = MacEntry.parse_state(state)
        old_entry = self._macs.get(addr)
        if not old_entry:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._macs[addr] = new_entry
                pending_updates.macs[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._macs.pop(addr)
            pending_updates.macs[addr] = old_entry

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Serialize all entries into messages."""
        return _serialize_mac_states_as_messages(self._macs)

    def as_json(self) -> dict:
        """Return all entries as a JSON-serializable dict keyed by ``str(mac)``."""
        return {
            str(addr): entry.as_json()
            for addr, entry in self._macs.items()
        }

    def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
        """Return the public state of ``mac`` (seen at ``address``).

        Unknown MAC addresses report 0 seconds remaining.
        """
        entry = self._macs.get(mac)
        if entry:
            allowed_remaining = entry.allowed_remaining()
        else:
            allowed_remaining = 0
        return cptypes.MacPublicState(
            address=address,
            mac=mac,
            allowed_remaining=allowed_remaining,
        )

    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
        """Allow ``mac`` for ``session_timeout`` seconds from now.

        An existing session is not renewed while more than
        ``renew_maximum * session_timeout`` seconds remain on it.

        Raises:
            NotReadyYet: if the entry is still not allowed after merging the
                login. This can only happen when ``now < entry.last_change``
                (out-of-sync clocks, e.g. a peer's logout stamped in our future).
        """
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
        entry = self._macs.get(mac)
        if entry is None:
            self._macs[mac] = new_entry
            pending_updates.macs[mac] = new_entry
            return
        if entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        if entry.merge(new_entry):
            # even a partial merge (e.g. only allow_until extended) is a real
            # change of replicated state and must be redistributed
            pending_updates.macs[mac] = entry
        # Check regardless of merge()'s result: merge can report a change (a
        # longer allow_until) while a newer logout keeps `allowed` False; the
        # login then hasn't taken effect and the caller must be told so.
        if not entry.allowed_remaining(now) > 0:
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
        """Revoke access for ``mac``; unknown MAC addresses are ignored.

        Raises:
            NotReadyYet: if the entry is still allowed after merging the
                logout, which can only happen with ``now <= entry.last_change``.
        """
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._macs.get(mac)
        if entry:
            if entry.merge(new_entry):
                pending_updates.macs[mac] = entry
            elif entry.allowed_remaining() > 0:
                # still logged in. can only happen with `now <= entry.last_change`
                # clocks not necessarily out of sync, but you can't logout in the same second you logged in
                wait = entry.last_change.epoch - now.epoch + 1
                raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
@dataclasses.dataclass
class PendingUpdates:
    """Collects MAC entries changed by Database operations that still need to be redistributed."""

    # changed entries keyed by MAC address; a later change to the same MAC replaces the earlier one
    macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)

    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Return the collected entries as a list of messages."""
        pending = self.macs
        return _serialize_mac_states_as_messages(pending)