2
0

sync controller database to disk and load it on start

This commit is contained in:
Stefan Bühler 2022-04-07 17:11:11 +02:00
parent 1e23b1205a
commit e1b1ec195f
6 changed files with 259 additions and 64 deletions

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ __pycache__
venv
capport.yaml
custom
capport.state
capport.state.new-*

View File

@ -28,7 +28,7 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
app.my_nc = mync
_logger.info("Running hub for API")
myapp = ApiHubApp()
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp)
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False)
app.my_hub = myhub
await myhub.run(task_status=task_status)
finally:

View File

@ -56,28 +56,28 @@ async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cp
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
assert app.my_hub # for mypy
pu = capport.database.PendingUpdates()
async with app.my_hub.database.make_changes() as pu:
try:
app.my_hub.database.login(mac, app.my_config.session_timeout, pending_updates=pu)
pu.login(mac, app.my_config.session_timeout)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
if pu:
_logger.debug(f'User {mac} (with IP {address}) logged in')
for msg in pu.serialize():
for msg in pu.serialized:
await app.my_hub.broadcast(msg)
async def user_logout(mac: cptypes.MacAddress) -> None:
assert app.my_hub # for mypy
pu = capport.database.PendingUpdates()
async with app.my_hub.database.make_changes() as pu:
try:
app.my_hub.database.logout(mac, pending_updates=pu)
pu.logout(mac)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
if pu:
_logger.debug(f'User {mac} logged out')
for msg in pu.serialize():
for msg in pu.serialized:
await app.my_hub.broadcast(msg)

View File

@ -265,9 +265,6 @@ class ControllerConn:
class HubApplication:
def is_controller(self) -> bool:
return False
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.info(f"New peer {peer_id}")
@ -287,13 +284,18 @@ class HubApplication:
class Hub:
def __init__(self, config: Config, app: HubApplication) -> None:
def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
self._config = config
self._instance_id = uuid.uuid4()
self._hostname = socket.getfqdn()
self.database = capport.database.Database()
self._app = app
self._is_controller = bool(app.is_controller())
self._is_controller = is_controller
state_filename: typing.Optional[str]
if is_controller:
state_filename = 'capport.state'
else:
state_filename = None
self.database = capport.database.Database(state_filename=state_filename)
self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
@ -324,6 +326,7 @@ class Hub:
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
async with trio.open_nursery() as nursery:
await nursery.start(self.database.run)
if self._is_controller:
await nursery.start(self._listen)
@ -412,12 +415,12 @@ class Hub:
pass
elif isinstance(variant, capport.comm.message.MacStates):
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
pu = capport.database.PendingUpdates()
async with self.database.make_changes() as pu:
for state in variant.states:
self.database.received_mac_state(state, pending_updates=pu)
if pu.macs:
pu.received_mac_state(state)
if pu:
# re-broadcast all received updates to all peers
await self.broadcast(*pu.serialize(), exclude=peer_id)
await self.broadcast(*pu.serialized, exclude=peer_id)
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
else:
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import logging
import typing
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.database
import capport.utils.cli
import capport.utils.nft_set
import trio
@ -22,15 +23,15 @@ class ControlApp(capport.comm.hub.HubApplication):
super().__init__()
self.nft_set = capport.utils.nft_set.NftSet()
def is_controller(self) -> bool:
return True
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
self.apply_db_entries(pending_updates.changes())
def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None:
# deploy changes to netfilter set
inserts = []
removals = []
now = cptypes.Timestamp.now()
for mac, state in pending_updates.macs.items():
for mac, state in entries:
rem = state.allowed_remaining(now)
if rem > 0:
inserts.append((mac, rem))
@ -42,9 +43,12 @@ class ControlApp(capport.comm.hub.HubApplication):
async def amain(config: capport.config.Config) -> None:
app = ControlApp()
hub = capport.comm.hub.Hub(config=config, app=app)
hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True)
app.hub = hub
await hub.run()
async with trio.open_nursery() as nursery:
# hub.run loads the statefile from disk before signalling it was "started"
await nursery.start(hub.run)
app.apply_db_entries(hub.database.entries())
def main() -> None:

View File

@ -1,11 +1,20 @@
from __future__ import annotations
import contextlib
import dataclasses
import logging
import os
import struct
import typing
import google.protobuf.message
import trio
import capport.comm.message
from capport import cptypes
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class MacEntry:
@ -113,30 +122,82 @@ def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacE
return [s.to_message() for s in _serialize_mac_states(macs)]
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
chunk = states.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
return len_bytes + chunk
class NotReadyYet(Exception):
def __init__(self, msg: str, wait: int):
self.wait = wait # seconds to wait
super().__init__(msg)
@dataclasses.dataclass
class Database:
_macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def __init__(self, state_filename: typing.Optional[str]=None):
self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename
self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]] = None
def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._macs[addr] = new_entry
pending_updates.macs[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._macs.pop(addr)
pending_updates.macs[addr] = old_entry
@contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
pu = PendingUpdates(self)
pu._closed = False
yield pu
pu._finish()
if pu:
self._changed_since_last_cleanup = True
if self._send_changes:
for state in pu.serialized_states:
await self._send_changes.send(state)
def _drop_outdated(self) -> None:
done = False
while not done:
depr: typing.Set[cptypes.MacAddress] = set()
now = cptypes.Timestamp.now()
done = True
for mac, entry in self._macs.items():
if entry.outdated(now):
depr.add(mac)
if len(depr) >= 1024:
# clear entries found so far, then try again
done = False
break
if depr:
self._changed_since_last_cleanup = True
for mac in depr:
del self._macs[mac]
async def run(self, task_status=trio.TASK_STATUS_IGNORED):
if self._state_filename:
await self._load_statefile()
task_status.started()
async with trio.open_nursery() as nursery:
if self._state_filename:
nursery.start_soon(self._run_statefile)
while True:
await trio.sleep(10) # sleep 15 minutes
_logger.debug("Running database cleanup")
self._drop_outdated()
if self._changed_since_last_cleanup:
self._changed_since_last_cleanup = False
if self._send_changes:
states = _serialize_mac_states(self._macs)
# trigger a resync
await self._send_changes.send(states)
# for initial handling of all data
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return list(self._macs.items())
# for initial sync with new peer
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
@ -146,6 +207,87 @@ class Database:
for addr, entry in self._macs.items()
}
async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx: trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx, rx = trio.open_memory_channel(64)
self._send_changes = tx
assert self._state_filename
filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}'
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
try:
async with await trio.open_file(tmp_filename, 'xb') as tf:
for states in all_states:
await tf.write(_states_to_chunk(states))
os.rename(tmp_filename, filename)
finally:
if os.path.exists(tmp_filename):
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
os.unlink(tmp_filename)
try:
while True:
async with await trio.open_file(filename, 'ab') as sf:
while True:
update = await rx.receive()
if isinstance(update, list):
break
await sf.write(_states_to_chunk(update))
# got a "list" update - i.e. a resync
with trio.CancelScope(shield=True):
await resync(update)
# now reopen normal statefile and continue appending updates
except trio.Cancelled:
_logger.info('Final sync to disk')
with trio.CancelScope(shield=True):
await resync(_serialize_mac_states(self._macs))
_logger.info('Final sync to disk done')
async def _load_statefile(self):
if not os.path.exists(self._state_filename):
return
_logger.info("Loading statefile")
# we're going to ignore changes from loading the file
pu = PendingUpdates(self)
pu._closed = False
async with await trio.open_file(self._state_filename, 'rb') as sf:
while True:
try:
len_bytes = await sf.read(4)
if not len_bytes:
return
if len(len_bytes) < 4:
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
return
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await sf.read(chunk_size)
except IOError as e:
_logger.error(f"Failed to read next chunk from statefile: {e}")
return
try:
states = capport.comm.message.MacStates()
states.ParseFromString(chunk)
except google.protobuf.message.DecodeError as e:
_logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
continue
for state in states:
errors = 0
try:
pu.received_mac_state(state)
except Exception as e:
errors += 1
if errors < 5:
_logger.error(f'Failed to handle state: {e}')
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)
if entry:
@ -158,43 +300,87 @@ class Database:
allowed_remaining=allowed_remaining,
)
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
class PendingUpdates:
def __init__(self, database: Database):
self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {}
self._database = database
self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
self._serialized: typing.List[capport.comm.message.Message] = []
def __bool__(self) -> bool:
return bool(self._changes)
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return self._changes.items()
@property
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
assert self._closed
return self._serialized_states
@property
def serialized(self) -> typing.List[capport.comm.message.Message]:
assert self._closed
return self._serialized
def _finish(self):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
self._closed = True
self._serialized_states = _serialize_mac_states(self._changes)
self._serialized = [s.to_message() for s in self._serialized_states]
def received_mac_state(self, state: capport.comm.message.MacState):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._database._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._database._macs[addr] = new_entry
self._changes[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._database._macs.pop(addr)
self._changes[addr] = old_entry
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
entry = self._macs.get(mac)
entry = self._database._macs.get(mac)
if not entry:
self._macs[mac] = new_entry
pending_updates.macs[mac] = new_entry
self._database._macs[mac] = new_entry
self._changes[mac] = new_entry
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
# too much time left on clock, not renewing session
return
elif entry.merge(new_entry):
pending_updates.macs[mac] = entry
self._changes[mac] = entry
elif not entry.allowed_remaining() > 0:
# entry should have been updated - can only fail due to `now < entry.last_change`
# i.e. out of sync clocks
wait = entry.last_change.epoch - now.epoch
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
def logout(self, mac: cptypes.MacAddress):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
entry = self._macs.get(mac)
entry = self._database._macs.get(mac)
if entry:
if entry.merge(new_entry):
pending_updates.macs[mac] = entry
self._changes[mac] = entry
elif entry.allowed_remaining() > 0:
# still logged in. can only happen with `now <= entry.last_change`
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
wait = entry.last_change.epoch - now.epoch + 1
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
@dataclasses.dataclass
class PendingUpdates:
macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self.macs)