390 lines
15 KiB
Python
390 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import dataclasses
|
|
import logging
|
|
import os
|
|
import struct
|
|
import typing
|
|
|
|
import google.protobuf.message
|
|
|
|
import trio
|
|
|
|
import capport.comm.message
|
|
from capport import cptypes
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclasses.dataclass
class MacEntry:
    """Replicated access state of a single MAC address.

    An entry can be removed once ``last_change`` is some time in the past and
    ``allow_until`` either was never set or has been reached (see
    :meth:`timeout` / :meth:`outdated`).
    """

    # grace periods (seconds) added to last_change / allow_until by timeout()
    WAIT_LAST_CHANGE_SECONDS = 60
    WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10

    # timestamp of the last change (as sent by the system initiating the change)
    last_change: cptypes.Timestamp
    # the device may only communicate with the internet if `allowed` is true
    # and `allow_until` is set; allow_until must not go backwards (and not get unset)
    allow_until: typing.Optional[cptypes.Timestamp]
    allowed: bool

    @staticmethod
    def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
        """Decode a protobuf ``MacState`` into an ``(address, entry)`` pair.

        Raises ``Exception`` if the MAC address has fewer than 6 bytes or if
        ``last_change`` is missing.
        """
        if len(msg.mac_address) < 6:
            raise Exception("Invalid MacState: mac_address too short")
        addr = cptypes.MacAddress(raw=msg.mac_address)
        last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
        if not last_change:
            raise Exception(f"Invalid MacState[{addr}]: missing last_change")
        entry = MacEntry(
            last_change=last_change,
            allow_until=cptypes.Timestamp.from_protobuf(msg.allow_until),
            allowed=msg.allowed,
        )
        return addr, entry

    def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
        """Encode this entry for ``addr`` as a protobuf ``MacState`` (unset allow_until -> 0)."""
        return capport.comm.message.MacState(
            mac_address=addr.raw,
            last_change=self.last_change.epoch,
            allow_until=self.allow_until.epoch if self.allow_until else 0,
            allowed=self.allowed,
        )

    def as_json(self) -> dict:
        """Return a JSON-friendly dict of epoch values (unset allow_until -> None)."""
        return {
            'last_change': self.last_change.epoch,
            'allow_until': self.allow_until.epoch if self.allow_until else None,
            'allowed': self.allowed,
        }

    def merge(self, new: MacEntry) -> bool:
        """Merge ``new`` into this entry in place; return True if anything changed.

        A newer ``last_change`` wins and brings its ``allowed`` flag along; on an
        equal ``last_change`` the entry becomes allowed if either side is.
        ``allow_until`` always becomes the maximum of both (when ``new`` has one).
        """
        changed = False
        if new.last_change > self.last_change:
            self.last_change = new.last_change
            self.allowed = new.allowed
            changed = True
        elif new.last_change == self.last_change and new.allowed and not self.allowed:
            self.allowed = True
            changed = True
        # an unset allow_until in `new` leaves the local value untouched
        if new.allow_until and (not self.allow_until or self.allow_until < new.allow_until):
            self.allow_until = new.allow_until
            changed = True
        return changed

    def timeout(self) -> cptypes.Timestamp:
        """Point in time after which this entry may be dropped."""
        deadline = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
        if self.allow_until:
            deadline = max(deadline, self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS)
        return cptypes.Timestamp(epoch=deadline)

    def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
        """Seconds of access left at ``now`` (default: current time); 0 if not allowed."""
        if self.allowed and self.allow_until:
            if not now:
                now = cptypes.Timestamp.now()
            return max(self.allow_until.epoch - now.epoch, 0)
        return 0

    def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
        """True if ``now`` (default: current time) lies past :meth:`timeout`."""
        current = now if now else cptypes.Timestamp.now()
        return current.epoch > self.timeout().epoch
|
|
|
|
|
|
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
    """Serialize ``macs`` into ``MacStates`` messages of at most 1024 states each.

    Returns an empty list for an empty mapping. (Might be used to serialize
    into a file as well - the ``Message`` wrapper is not needed there.)
    """
    batches: typing.List[capport.comm.message.MacStates] = []
    batch = capport.comm.message.MacStates()
    for mac_addr, mac_entry in macs.items():
        batch.states.append(mac_entry.to_state(mac_addr))
        if len(batch.states) < 1024:
            continue
        # batch is full: flush it and start a new one
        batches.append(batch)
        batch = capport.comm.message.MacStates()
    if len(batch.states) > 0:
        batches.append(batch)
    return batches
|
|
|
|
|
|
def _serialize_mac_states_as_messages(
    macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
    """Serialize ``macs`` like ``_serialize_mac_states`` and wrap each batch via ``to_message()``."""
    messages: typing.List[capport.comm.message.Message] = []
    for batch in _serialize_mac_states(macs):
        messages.append(batch.to_message())
    return messages
|
|
|
|
|
|
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
|
|
chunk = states.SerializeToString(deterministic=True)
|
|
chunk_size = len(chunk)
|
|
len_bytes = struct.pack('!I', chunk_size)
|
|
return len_bytes + chunk
|
|
|
|
|
|
class NotReadyYet(Exception):
    """A change can't be applied yet; retry after ``wait`` seconds."""

    def __init__(self, msg: str, wait: int):
        super().__init__(msg)
        self.wait = wait  # seconds to wait before retrying
|
|
|
|
|
|
class Database:
    """In-memory table of MAC address states, optionally persisted to a state file.

    All modifications go through :meth:`make_changes`. When a state file is
    configured, :meth:`run` loads it on startup and a background task keeps it
    up to date: each change batch is appended, and the whole file is rewritten
    whenever a periodic cleanup saw changes.
    """

    def __init__(self, state_filename: typing.Optional[str] = None):
        self._macs: dict[cptypes.MacAddress, MacEntry] = {}
        self._state_filename = state_filename
        # set when entries were changed or dropped; run() then triggers a full resync
        self._changed_since_last_cleanup = False
        # set by _run_statefile(): a single MacStates is appended to the state
        # file, a list of MacStates replaces the whole file
        self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
            capport.comm.message.MacStates,
            typing.List[capport.comm.message.MacStates],
        ]]] = None

    @contextlib.asynccontextmanager
    async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
        """Open a batch of changes; yields a PendingUpdates to record them in.

        Changes are applied to the database as they are made. When the
        ``async with`` block exits normally the batch is closed and, if
        anything changed, its serialized states are forwarded to the state
        file writer (if one is running). If the block raises, the code after
        ``yield`` does not run, so nothing is finished or forwarded.
        """
        pu = PendingUpdates(self)
        pu._closed = False
        yield pu

        pu._finish()
        if pu:
            self._changed_since_last_cleanup = True
            if self._send_changes:
                for state in pu.serialized_states:
                    await self._send_changes.send(state)

    def _drop_outdated(self) -> None:
        """Remove all entries whose timeout has passed.

        Outdated addresses are collected in batches of at most 1024 and
        deleted after each batch (the dict can't be modified while it is being
        iterated); scanning restarts until a full pass stays below the limit.
        """
        done = False
        while not done:
            depr: typing.Set[cptypes.MacAddress] = set()
            now = cptypes.Timestamp.now()
            done = True
            for mac, entry in self._macs.items():
                if entry.outdated(now):
                    depr.add(mac)
                    if len(depr) >= 1024:
                        # clear entries found so far, then try again
                        done = False
                        break
            if depr:
                self._changed_since_last_cleanup = True
            for mac in depr:
                del self._macs[mac]

    async def run(self, task_status=trio.TASK_STATUS_IGNORED):
        """Load the state file (if configured), then run forever.

        Calls ``task_status.started()`` once loading is done. Every 5 minutes
        outdated entries are dropped and, if anything changed since the last
        cleanup, the complete state is sent to the state file writer so it
        rewrites the file.
        """
        if self._state_filename:
            await self._load_statefile()
        task_status.started()
        async with trio.open_nursery() as nursery:
            if self._state_filename:
                nursery.start_soon(self._run_statefile)
            while True:
                await trio.sleep(300)  # cleanup every 5 minutes
                _logger.debug("Running database cleanup")
                self._drop_outdated()
                if self._changed_since_last_cleanup:
                    self._changed_since_last_cleanup = False
                    if self._send_changes:
                        states = _serialize_mac_states(self._macs)
                        # trigger a resync
                        await self._send_changes.send(states)

    # for initial handling of all data
    def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
        """Return a snapshot list of all ``(address, entry)`` pairs."""
        return list(self._macs.items())

    # for initial sync with new peer
    def serialize(self) -> typing.List[capport.comm.message.Message]:
        """Serialize the complete database as a list of messages."""
        return _serialize_mac_states_as_messages(self._macs)

    def as_json(self) -> dict:
        """Return all entries as a JSON-friendly dict keyed by ``str(address)``."""
        return {
            str(addr): entry.as_json()
            for addr, entry in self._macs.items()
        }

    async def _run_statefile(self) -> None:
        """Write changes to the state file until cancelled.

        Single ``MacStates`` updates are appended to the file; a list of
        ``MacStates`` rewrites the whole file atomically (temp file + rename).
        On cancellation a final full sync is written (shielded) and the
        cancellation is propagated.
        """
        rx: trio.MemoryReceiveChannel[typing.Union[
            capport.comm.message.MacStates,
            typing.List[capport.comm.message.MacStates],
        ]]
        tx: trio.MemorySendChannel[typing.Union[
            capport.comm.message.MacStates,
            typing.List[capport.comm.message.MacStates],
        ]]
        tx, rx = trio.open_memory_channel(64)
        self._send_changes = tx

        assert self._state_filename
        filename: str = self._state_filename
        tmp_filename = f'{filename}.new-{os.getpid()}'

        async def resync(all_states: typing.List[capport.comm.message.MacStates]):
            """Write ``all_states`` to a temp file, then rename it over the state file."""
            try:
                async with await trio.open_file(tmp_filename, 'xb') as tf:
                    for states in all_states:
                        await tf.write(_states_to_chunk(states))
                os.rename(tmp_filename, filename)
            finally:
                # only still present if writing or renaming failed
                if os.path.exists(tmp_filename):
                    _logger.warning(f'Removing (failed) state export file {tmp_filename}')
                    os.unlink(tmp_filename)

        try:
            while True:
                async with await trio.open_file(filename, 'ab', buffering=0) as sf:
                    while True:
                        update = await rx.receive()
                        if isinstance(update, list):
                            break
                        await sf.write(_states_to_chunk(update))
                # got a "list" update - i.e. a resync
                with trio.CancelScope(shield=True):
                    await resync(update)
                # now reopen normal statefile and continue appending updates
        except trio.Cancelled:
            _logger.info('Final sync to disk')
            with trio.CancelScope(shield=True):
                await resync(_serialize_mac_states(self._macs))
            _logger.info('Final sync to disk done')
            # trio requires Cancelled to propagate; swallowing it would make this
            # task look like it finished normally
            raise

    async def _load_statefile(self):
        """Replay the state file into the database without broadcasting the changes.

        The file is a sequence of chunks, each a 4-byte big-endian length
        followed by a serialized ``MacStates`` message. Reading stops at EOF,
        on a read error or at a truncated chunk; chunks that fail to decode
        are skipped. Only the first few per-state failures are logged.
        """
        if not os.path.exists(self._state_filename):
            return
        _logger.info("Loading statefile")
        # we're going to ignore changes from loading the file
        pu = PendingUpdates(self)
        pu._closed = False
        # counted over the whole file so the logging limit below actually applies
        errors = 0
        async with await trio.open_file(self._state_filename, 'rb') as sf:
            while True:
                try:
                    len_bytes = await sf.read(4)
                    if not len_bytes:
                        break
                    if len(len_bytes) < 4:
                        _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
                        break
                    chunk_size, = struct.unpack('!I', len_bytes)
                    chunk = await sf.read(chunk_size)
                except IOError as e:
                    _logger.error(f"Failed to read next chunk from statefile: {e}")
                    break
                if len(chunk) < chunk_size:
                    # truncated write (e.g. crash while appending); a partial
                    # chunk might still parse into incomplete states
                    _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
                    break
                try:
                    states = capport.comm.message.MacStates()
                    states.ParseFromString(chunk)
                except google.protobuf.message.DecodeError as e:
                    _logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
                    continue
                for state in states.states:
                    try:
                        pu.received_mac_state(state)
                    except Exception as e:
                        errors += 1
                        if errors < 5:
                            _logger.error(f'Failed to handle state: {e}')
        if errors >= 5:
            _logger.error(f'Failed to handle {errors} states from statefile in total')

    def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
        """Return the public state of ``mac`` (seen at ``address``); unknown MACs get 0 remaining."""
        entry = self._macs.get(mac)
        if entry:
            allowed_remaining = entry.allowed_remaining()
        else:
            allowed_remaining = 0
        return cptypes.MacPublicState(
            address=address,
            mac=mac,
            allowed_remaining=allowed_remaining,
        )
|
|
|
|
|
|
class PendingUpdates:
    """A batch of changes to a Database.

    Every change is applied to the database's entries immediately; the changed
    entries are also collected so that, once the batch is closed via
    ``_finish()``, they can be serialized for redistribution.
    """

    def __init__(self, database: Database):
        self._changes: dict[cptypes.MacAddress, MacEntry] = {}
        self._database = database
        # starts closed; the creator opens the batch by clearing this flag
        self._closed = True
        self._serialized_states: typing.List[capport.comm.message.MacStates] = []
        self._serialized: typing.List[capport.comm.message.Message] = []

    def __bool__(self) -> bool:
        """True if the batch recorded any changes."""
        return bool(self._changes)

    def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
        """Return the changed ``(address, entry)`` pairs."""
        return self._changes.items()

    @property
    def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
        """Changed entries as ``MacStates`` batches; only valid once closed."""
        assert self._closed
        return self._serialized_states

    @property
    def serialized(self) -> typing.List[capport.comm.message.Message]:
        """Changed entries as messages; only valid once closed."""
        assert self._closed
        return self._serialized

    def _ensure_open(self) -> None:
        """Raise if the batch is closed and must not be modified anymore."""
        if self._closed:
            raise Exception("Can't change closed PendingUpdates")

    def _finish(self):
        """Close the batch and serialize the collected changes."""
        self._ensure_open()
        self._closed = True
        self._serialized_states = _serialize_mac_states(self._changes)
        self._serialized = [s.to_message() for s in self._serialized_states]

    def received_mac_state(self, state: capport.comm.message.MacState):
        """Merge a ``MacState`` received from elsewhere into the database.

        Unknown addresses are only added (and redistributed) if the new entry
        isn't already outdated. For known addresses the entry is merged; if
        that changed it, the change is recorded, and an entry that became
        outdated is removed locally while still being redistributed.
        Raises ``Exception`` (via ``MacEntry.parse_state``) on invalid states.
        """
        self._ensure_open()
        (addr, new_entry) = MacEntry.parse_state(state)
        old_entry = self._database._macs.get(addr)
        if not old_entry:
            # only redistribute if not outdated
            if not new_entry.outdated():
                self._database._macs[addr] = new_entry
                self._changes[addr] = new_entry
        elif old_entry.merge(new_entry):
            if old_entry.outdated():
                # remove local entry, but still redistribute
                self._database._macs.pop(addr)
            self._changes[addr] = old_entry

    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
        """Allow ``mac`` for ``session_timeout`` seconds from now.

        An existing session is left alone while more than
        ``renew_maximum * session_timeout`` seconds remain. Raises
        ``NotReadyYet`` if the stored entry couldn't be updated because its
        ``last_change`` lies in the future (out of sync clocks).
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
        new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)

        entry = self._database._macs.get(mac)
        if not entry:
            self._database._macs[mac] = new_entry
            self._changes[mac] = new_entry
        elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
            # too much time left on clock, not renewing session
            return
        elif entry.merge(new_entry):
            self._changes[mac] = entry
        elif not entry.allowed_remaining(now) > 0:
            # entry should have been updated - can only fail due to `now < entry.last_change`
            # i.e. out of sync clocks
            wait = entry.last_change.epoch - now.epoch
            raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

    def logout(self, mac: cptypes.MacAddress):
        """Revoke access for ``mac`` (no-op for unknown addresses).

        Raises ``NotReadyYet`` if the entry is still allowed but couldn't be
        updated because the logout isn't newer than its ``last_change``.
        """
        self._ensure_open()
        now = cptypes.Timestamp.now()
        new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
        entry = self._database._macs.get(mac)
        if entry:
            if entry.merge(new_entry):
                self._changes[mac] = entry
            elif entry.allowed_remaining(now) > 0:
                # still logged in. can only happen with `now <= entry.last_change`
                # clocks not necessarily out of sync, but you can't logout in the same second you logged in
                wait = entry.last_change.epoch - now.epoch + 1
                raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
|