# python-capport/src/capport/database.py
from __future__ import annotations
import contextlib
import dataclasses
import logging
import os
import struct
import typing
import google.protobuf.message
import trio
import capport.comm.message
from capport import cptypes
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class MacEntry:
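    """Allow state and lifetime tracked for a single client MAC address."""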
    # An entry can be removed once last_change is old enough and allow_until
    # either wasn't set or has already passed.
WAIT_LAST_CHANGE_SECONDS = 60
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
# last_change: timestamp of last change (sent by system initiating the change)
last_change: cptypes.Timestamp
    # the device can communicate with the internet only while allowed is true and allow_until is set (and not yet reached)
# allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp]
allowed: bool
@staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
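        """Parse a protobuf MacState into an (address, MacEntry) tuple.

        Raises if the MAC address is too short or last_change is missing.
        """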
if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address)
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
if not last_change:
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
allow_until = 0
if self.allow_until:
allow_until = self.allow_until.epoch
return capport.comm.message.MacState(
mac_address=addr.raw,
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def as_json(self) -> dict:
allow_until = None
if self.allow_until:
allow_until = self.allow_until.epoch
return dict(
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def merge(self, new: MacEntry) -> bool:
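        """Merge a received entry into this one; return True if anything changed.

        A newer last_change wins for `allowed`; on an equal last_change the entry
        stays allowed if either side allows. `allow_until` is always raised to the
        maximum of both sides and never goes backwards.
        """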
changed = False
if new.last_change > self.last_change:
changed = True
self.last_change = new.last_change
self.allowed = new.allowed
elif new.last_change == self.last_change:
# same last_change: set allowed if one allowed
if new.allowed and not self.allowed:
changed = True
self.allowed = True
# set allow_until to max of both
if new.allow_until: # if not set nothing to change in local data
if not self.allow_until or self.allow_until < new.allow_until:
changed = True
self.allow_until = new.allow_until
return changed
def timeout(self) -> cptypes.Timestamp:
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
if self.allow_until:
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
if eau > elc:
return cptypes.Timestamp(epoch=eau)
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
now = cptypes.Timestamp.now()
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
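    """Pack all entries into MacStates protobuf messages, at most 1024 states per message."""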
result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
state = entry.to_state(addr)
current.states.append(state)
if len(current.states) >= 1024: # split into messages with 1024 states
result.append(current)
current = capport.comm.message.MacStates()
if len(current.states):
result.append(current)
return result
def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
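    """Frame a MacStates message for the state file: 4-byte big-endian length prefix, then the serialized payload."""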
chunk = states.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
return len_bytes + chunk
class NotReadyYet(Exception):
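    """The requested login/logout can't be applied yet; retry after `wait` seconds."""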
def __init__(self, msg: str, wait: int):
self.wait = wait # seconds to wait
super().__init__(msg)
class Database:
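    """In-memory store of per-MAC entries, optionally persisted to a state file (appended updates plus periodic full rewrites)."""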
def __init__(self, state_filename: typing.Optional[str] = None):
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename
self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]] = None
@contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
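        """Collect changes in a PendingUpdates and publish them when the block exits.

        Rough usage sketch (variable names here are assumptions, not part of the API):

            async with database.make_changes() as pu:
                pu.login(mac, session_timeout=3600)
            # after the block: pu.serialized holds the messages describing the changes
        """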
pu = PendingUpdates(self)
pu._closed = False
yield pu
pu._finish()
if pu:
self._changed_since_last_cleanup = True
if self._send_changes:
for state in pu.serialized_states:
await self._send_changes.send(state)
def _drop_outdated(self) -> None:
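        """Drop outdated entries, removing at most 1024 per pass until none are left."""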
done = False
while not done:
depr: typing.Set[cptypes.MacAddress] = set()
now = cptypes.Timestamp.now()
done = True
for mac, entry in self._macs.items():
if entry.outdated(now):
depr.add(mac)
if len(depr) >= 1024:
# clear entries found so far, then try again
done = False
break
if depr:
self._changed_since_last_cleanup = True
for mac in depr:
del self._macs[mac]
async def run(self, task_status=trio.TASK_STATUS_IGNORED):
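        """Long-running task: load the state file, start the state file writer, and every
        5 minutes drop outdated entries and trigger a resync if anything changed.
        """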
if self._state_filename:
await self._load_statefile()
task_status.started()
async with trio.open_nursery() as nursery:
if self._state_filename:
nursery.start_soon(self._run_statefile)
while True:
await trio.sleep(300) # cleanup every 5 minutes
_logger.debug("Running database cleanup")
self._drop_outdated()
if self._changed_since_last_cleanup:
self._changed_since_last_cleanup = False
if self._send_changes:
states = _serialize_mac_states(self._macs)
# trigger a resync
await self._send_changes.send(states)
# for initial handling of all data
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return list(self._macs.items())
# for initial sync with new peer
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
return {
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
async def _run_statefile(self) -> None:
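        """Append updates received over the memory channel to the state file; a list of
        MacStates means "rewrite everything" and triggers a full resync.
        """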
rx: trio.MemoryReceiveChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx: trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx, rx = trio.open_memory_channel(64)
self._send_changes = tx
assert self._state_filename
filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}'
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
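            """Rewrite the whole state file atomically: write to a temp file, then rename it over the old one."""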
try:
async with await trio.open_file(tmp_filename, 'xb') as tf:
for states in all_states:
await tf.write(_states_to_chunk(states))
os.rename(tmp_filename, filename)
finally:
if os.path.exists(tmp_filename):
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
os.unlink(tmp_filename)
try:
while True:
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
while True:
update = await rx.receive()
if isinstance(update, list):
break
await sf.write(_states_to_chunk(update))
# got a "list" update - i.e. a resync
with trio.CancelScope(shield=True):
await resync(update)
# now reopen normal statefile and continue appending updates
except trio.Cancelled:
_logger.info('Final sync to disk')
with trio.CancelScope(shield=True):
await resync(_serialize_mac_states(self._macs))
_logger.info('Final sync to disk done')
async def _load_statefile(self):
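        """Replay the length-prefixed chunks of the state file into the in-memory database."""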
if not os.path.exists(self._state_filename):
return
_logger.info("Loading statefile")
# we're going to ignore changes from loading the file
pu = PendingUpdates(self)
pu._closed = False
async with await trio.open_file(self._state_filename, 'rb') as sf:
while True:
try:
len_bytes = await sf.read(4)
if not len_bytes:
return
if len(len_bytes) < 4:
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
return
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await sf.read(chunk_size)
except IOError as e:
_logger.error(f"Failed to read next chunk from statefile: {e}")
return
try:
states = capport.comm.message.MacStates()
states.ParseFromString(chunk)
except google.protobuf.message.DecodeError as e:
_logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
continue
                errors = 0
                for state in states.states:
                    try:
                        pu.received_mac_state(state)
                    except Exception as e:
                        errors += 1
                        # only log the first few broken states per chunk
                        if errors < 5:
                            _logger.error(f'Failed to handle state: {e}')
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)
if entry:
allowed_remaining = entry.allowed_remaining()
else:
allowed_remaining = 0
return cptypes.MacPublicState(
address=address,
mac=mac,
allowed_remaining=allowed_remaining,
)
class PendingUpdates:
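    """Changes collected while a Database.make_changes() block is open.

    After _finish() the updates are also available pre-serialized, both as
    MacStates (for the state file) and as Messages.
    """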
def __init__(self, database: Database):
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database
self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
self._serialized: typing.List[capport.comm.message.Message] = []
def __bool__(self) -> bool:
return bool(self._changes)
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return self._changes.items()
@property
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
assert self._closed
return self._serialized_states
@property
def serialized(self) -> typing.List[capport.comm.message.Message]:
assert self._closed
return self._serialized
def _finish(self):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
self._closed = True
self._serialized_states = _serialize_mac_states(self._changes)
self._serialized = [s.to_message() for s in self._serialized_states]
def received_mac_state(self, state: capport.comm.message.MacState):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._database._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._database._macs[addr] = new_entry
self._changes[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._database._macs.pop(addr)
self._changes[addr] = old_entry
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
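        """Allow `mac` for `session_timeout` seconds from now.

        If more than `renew_maximum * session_timeout` seconds are still allowed,
        the session is left as it is instead of being renewed.
        """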
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
entry = self._database._macs.get(mac)
if not entry:
self._database._macs[mac] = new_entry
self._changes[mac] = new_entry
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
# too much time left on clock, not renewing session
return
elif entry.merge(new_entry):
self._changes[mac] = entry
        elif entry.allowed_remaining() <= 0:
# entry should have been updated - can only fail due to `now < entry.last_change`
# i.e. out of sync clocks
wait = entry.last_change.epoch - now.epoch
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
def logout(self, mac: cptypes.MacAddress):
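        """Mark `mac` as no longer allowed."""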
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
entry = self._database._macs.get(mac)
if entry:
if entry.merge(new_entry):
self._changes[mac] = entry
elif entry.allowed_remaining() > 0:
# still logged in. can only happen with `now <= entry.last_change`
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
wait = entry.last_change.epoch - now.epoch + 1
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)