
sync controller database to disk and load it on start

Stefan Bühler 2022-04-07 17:11:11 +02:00
parent 1e23b1205a
commit e1b1ec195f
6 changed files with 259 additions and 64 deletions
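In short: `PendingUpdates` is no longer passed into `Database.login`/`logout`; callers now open a change set with `Database.make_changes()`, mutate it, and broadcast the messages it serialized on exit, while the controller additionally persists every change to `capport.state`. A minimal sketch of the new calling pattern as the API and hub hunks below use it (the wrapper function itself is illustrative, not part of the commit):

async def allow_client(hub, mac, session_timeout: int) -> None:
    # hub: capport.comm.hub.Hub -- its Database now owns the change tracking
    async with hub.database.make_changes() as pu:
        # raises capport.database.NotReadyYet if clocks are out of sync
        pu.login(mac, session_timeout)
    # on exit the context manager closes pu and serializes the changes
    if pu:  # truthy only if any MacEntry actually changed
        for msg in pu.serialized:
            await hub.broadcast(msg)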

.gitignore

@@ -5,3 +5,5 @@ __pycache__
 venv
 capport.yaml
 custom
+capport.state
+capport.state.new-*


@@ -28,7 +28,7 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
         app.my_nc = mync
         _logger.info("Running hub for API")
         myapp = ApiHubApp()
-        myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp)
+        myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False)
         app.my_hub = myhub
         await myhub.run(task_status=task_status)
     finally:


@@ -56,28 +56,28 @@ async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cp
 async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
     assert app.my_hub  # for mypy
-    pu = capport.database.PendingUpdates()
-    try:
-        app.my_hub.database.login(mac, app.my_config.session_timeout, pending_updates=pu)
-    except capport.database.NotReadyYet as e:
-        quart.abort(500, str(e))
-    if pu.macs:
+    async with app.my_hub.database.make_changes() as pu:
+        try:
+            pu.login(mac, app.my_config.session_timeout)
+        except capport.database.NotReadyYet as e:
+            quart.abort(500, str(e))
+    if pu:
         _logger.debug(f'User {mac} (with IP {address}) logged in')
-        for msg in pu.serialize():
+        for msg in pu.serialized:
             await app.my_hub.broadcast(msg)


 async def user_logout(mac: cptypes.MacAddress) -> None:
     assert app.my_hub  # for mypy
-    pu = capport.database.PendingUpdates()
-    try:
-        app.my_hub.database.logout(mac, pending_updates=pu)
-    except capport.database.NotReadyYet as e:
-        quart.abort(500, str(e))
-    if pu.macs:
+    async with app.my_hub.database.make_changes() as pu:
+        try:
+            pu.logout(mac)
+        except capport.database.NotReadyYet as e:
+            quart.abort(500, str(e))
+    if pu:
         _logger.debug(f'User {mac} logged out')
-        for msg in pu.serialize():
+        for msg in pu.serialized:
             await app.my_hub.broadcast(msg)


@@ -265,9 +265,6 @@ class ControllerConn:

 class HubApplication:
-    def is_controller(self) -> bool:
-        return False
-
     async def new_peer(self, *, peer_id: uuid.UUID) -> None:
         _logger.info(f"New peer {peer_id}")
@@ -287,13 +284,18 @@ class HubApplication:
 class Hub:
-    def __init__(self, config: Config, app: HubApplication) -> None:
+    def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
         self._config = config
         self._instance_id = uuid.uuid4()
         self._hostname = socket.getfqdn()
-        self.database = capport.database.Database()
         self._app = app
-        self._is_controller = bool(app.is_controller())
+        self._is_controller = is_controller
+        state_filename: typing.Optional[str]
+        if is_controller:
+            state_filename = 'capport.state'
+        else:
+            state_filename = None
+        self.database = capport.database.Database(state_filename=state_filename)
         self._anon_context = ssl.SSLContext()
         # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
         self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
@@ -324,6 +326,7 @@ class Hub:
     async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
         async with trio.open_nursery() as nursery:
+            await nursery.start(self.database.run)
             if self._is_controller:
                 await nursery.start(self._listen)
@@ -412,12 +415,12 @@ class Hub:
             pass
         elif isinstance(variant, capport.comm.message.MacStates):
             await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
-            pu = capport.database.PendingUpdates()
-            for state in variant.states:
-                self.database.received_mac_state(state, pending_updates=pu)
-            if pu.macs:
+            async with self.database.make_changes() as pu:
+                for state in variant.states:
+                    pu.received_mac_state(state)
+            if pu:
                 # re-broadcast all received updates to all peers
-                await self.broadcast(*pu.serialize(), exclude=peer_id)
+                await self.broadcast(*pu.serialized, exclude=peer_id)
                 await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
         else:
             await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)


@@ -1,12 +1,13 @@
 from __future__ import annotations

 import logging
+import typing
 import uuid

-import capport.database
 import capport.comm.hub
 import capport.comm.message
 import capport.config
+import capport.database
 import capport.utils.cli
 import capport.utils.nft_set
 import trio
@@ -22,15 +23,15 @@ class ControlApp(capport.comm.hub.HubApplication):
         super().__init__()
         self.nft_set = capport.utils.nft_set.NftSet()

-    def is_controller(self) -> bool:
-        return True
-
     async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
+        self.apply_db_entries(pending_updates.changes())
+
+    def apply_db_entries(self, entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]]) -> None:
         # deploy changes to netfilter set
         inserts = []
         removals = []
         now = cptypes.Timestamp.now()
-        for mac, state in pending_updates.macs.items():
+        for mac, state in entries:
             rem = state.allowed_remaining(now)
             if rem > 0:
                 inserts.append((mac, rem))
@@ -42,9 +43,12 @@ class ControlApp(capport.comm.hub.HubApplication):
 async def amain(config: capport.config.Config) -> None:
     app = ControlApp()
-    hub = capport.comm.hub.Hub(config=config, app=app)
+    hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True)
     app.hub = hub
-    await hub.run()
+    async with trio.open_nursery() as nursery:
+        # hub.run loads the statefile from disk before signalling it was "started"
+        await nursery.start(hub.run)
+        app.apply_db_entries(hub.database.entries())


 def main() -> None:
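The controller relies on trio's startup handshake here: a task launched with `nursery.start()` runs until it calls `task_status.started()`, and only then does `nursery.start()` return. Per the comment above, `hub.run` (via `Database.run` and `_load_statefile`) signals "started" only after the state file is in memory, so `apply_db_entries(hub.database.entries())` sees the loaded entries. A self-contained toy sketch of that handshake (names here are illustrative, not from the commit):

import trio

async def loader(task_status=trio.TASK_STATUS_IGNORED):
    await trio.sleep(0.1)          # stands in for _load_statefile()
    task_status.started()          # only now does nursery.start() return
    await trio.sleep_forever()     # stands in for the periodic cleanup loop

async def main():
    async with trio.open_nursery() as nursery:
        await nursery.start(loader)    # blocks until started() was called
        print("state loaded, safe to apply entries")
        nursery.cancel_scope.cancel()

trio.run(main)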


@@ -1,11 +1,20 @@
 from __future__ import annotations

+import contextlib
 import dataclasses
+import logging
+import os
+import struct
 import typing

+import google.protobuf.message
+import trio
+
 import capport.comm.message
 from capport import cptypes

+_logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class MacEntry:
@@ -113,30 +122,82 @@ def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacE
     return [s.to_message() for s in _serialize_mac_states(macs)]


+def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
+    chunk = states.SerializeToString(deterministic=True)
+    chunk_size = len(chunk)
+    len_bytes = struct.pack('!I', chunk_size)
+    return len_bytes + chunk
+
+
 class NotReadyYet(Exception):
     def __init__(self, msg: str, wait: int):
         self.wait = wait  # seconds to wait
         super().__init__(msg)


-@dataclasses.dataclass
 class Database:
-    _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
+    def __init__(self, state_filename: typing.Optional[str]=None):
+        self._macs: typing.Dict[cptypes.MacAddress, MacEntry] = {}
+        self._state_filename = state_filename
+        self._changed_since_last_cleanup = False
+        self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
+            capport.comm.message.MacStates,
+            typing.List[capport.comm.message.MacStates],
+        ]]] = None

-    def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
-        (addr, new_entry) = MacEntry.parse_state(state)
-        old_entry = self._macs.get(addr)
-        if not old_entry:
-            # only redistribute if not outdated
-            if not new_entry.outdated():
-                self._macs[addr] = new_entry
-                pending_updates.macs[addr] = new_entry
-        elif old_entry.merge(new_entry):
-            if old_entry.outdated():
-                # remove local entry, but still redistribute
-                self._macs.pop(addr)
-            pending_updates.macs[addr] = old_entry
+    @contextlib.asynccontextmanager
+    async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
+        pu = PendingUpdates(self)
+        pu._closed = False
+        yield pu
+        pu._finish()
+        if pu:
+            self._changed_since_last_cleanup = True
+            if self._send_changes:
+                for state in pu.serialized_states:
+                    await self._send_changes.send(state)
+
+    def _drop_outdated(self) -> None:
+        done = False
+        while not done:
+            depr: typing.Set[cptypes.MacAddress] = set()
+            now = cptypes.Timestamp.now()
+            done = True
+            for mac, entry in self._macs.items():
+                if entry.outdated(now):
+                    depr.add(mac)
+                    if len(depr) >= 1024:
+                        # clear entries found so far, then try again
+                        done = False
+                        break
+            if depr:
+                self._changed_since_last_cleanup = True
+                for mac in depr:
+                    del self._macs[mac]
+
+    async def run(self, task_status=trio.TASK_STATUS_IGNORED):
+        if self._state_filename:
+            await self._load_statefile()
+        task_status.started()
+        async with trio.open_nursery() as nursery:
+            if self._state_filename:
+                nursery.start_soon(self._run_statefile)
+            while True:
+                await trio.sleep(10)  # sleep 15 minutes
+                _logger.debug("Running database cleanup")
+                self._drop_outdated()
+                if self._changed_since_last_cleanup:
+                    self._changed_since_last_cleanup = False
+                    if self._send_changes:
+                        states = _serialize_mac_states(self._macs)
+                        # trigger a resync
+                        await self._send_changes.send(states)
+
+    # for initial handling of all data
+    def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
+        return list(self._macs.items())
+
+    # for initial sync with new peer
     def serialize(self) -> typing.List[capport.comm.message.Message]:
         return _serialize_mac_states_as_messages(self._macs)
@@ -146,6 +207,87 @@ class Database:
             for addr, entry in self._macs.items()
         }

+    async def _run_statefile(self) -> None:
+        rx: trio.MemoryReceiveChannel[typing.Union[
+            capport.comm.message.MacStates,
+            typing.List[capport.comm.message.MacStates],
+        ]]
+        tx: trio.MemorySendChannel[typing.Union[
+            capport.comm.message.MacStates,
+            typing.List[capport.comm.message.MacStates],
+        ]]
+        tx, rx = trio.open_memory_channel(64)
+        self._send_changes = tx
+
+        assert self._state_filename
+        filename: str = self._state_filename
+        tmp_filename = f'{filename}.new-{os.getpid()}'
+
+        async def resync(all_states: typing.List[capport.comm.message.MacStates]):
+            try:
+                async with await trio.open_file(tmp_filename, 'xb') as tf:
+                    for states in all_states:
+                        await tf.write(_states_to_chunk(states))
+                os.rename(tmp_filename, filename)
+            finally:
+                if os.path.exists(tmp_filename):
+                    _logger.warning(f'Removing (failed) state export file {tmp_filename}')
+                    os.unlink(tmp_filename)
+
+        try:
+            while True:
+                async with await trio.open_file(filename, 'ab') as sf:
+                    while True:
+                        update = await rx.receive()
+                        if isinstance(update, list):
+                            break
+                        await sf.write(_states_to_chunk(update))
+                # got a "list" update - i.e. a resync
+                with trio.CancelScope(shield=True):
+                    await resync(update)
+                # now reopen normal statefile and continue appending updates
+        except trio.Cancelled:
+            _logger.info('Final sync to disk')
+            with trio.CancelScope(shield=True):
+                await resync(_serialize_mac_states(self._macs))
+            _logger.info('Final sync to disk done')
+
+    async def _load_statefile(self):
+        if not os.path.exists(self._state_filename):
+            return
+        _logger.info("Loading statefile")
+        # we're going to ignore changes from loading the file
+        pu = PendingUpdates(self)
+        pu._closed = False
+        async with await trio.open_file(self._state_filename, 'rb') as sf:
+            while True:
+                try:
+                    len_bytes = await sf.read(4)
+                    if not len_bytes:
+                        return
+                    if len(len_bytes) < 4:
+                        _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
+                        return
+                    chunk_size, = struct.unpack('!I', len_bytes)
+                    chunk = await sf.read(chunk_size)
+                except IOError as e:
+                    _logger.error(f"Failed to read next chunk from statefile: {e}")
+                    return
+                try:
+                    states = capport.comm.message.MacStates()
+                    states.ParseFromString(chunk)
+                except google.protobuf.message.DecodeError as e:
+                    _logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
+                    continue
+                errors = 0
+                for state in states.states:
+                    try:
+                        pu.received_mac_state(state)
+                    except Exception as e:
+                        errors += 1
+                        if errors < 5:
+                            _logger.error(f'Failed to handle state: {e}')
+
     def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
         entry = self._macs.get(mac)
         if entry:
@@ -158,43 +300,87 @@ class Database:
             allowed_remaining=allowed_remaining,
         )

-    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
+
+class PendingUpdates:
+    def __init__(self, database: Database):
+        self._changes: typing.Dict[cptypes.MacAddress, MacEntry] = {}
+        self._database = database
+        self._closed = True
+        self._serialized_states: typing.List[capport.comm.message.MacStates] = []
+        self._serialized: typing.List[capport.comm.message.Message] = []
+
+    def __bool__(self) -> bool:
+        return bool(self._changes)
+
+    def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
+        return self._changes.items()
+
+    @property
+    def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
+        assert self._closed
+        return self._serialized_states
+
+    @property
+    def serialized(self) -> typing.List[capport.comm.message.Message]:
+        assert self._closed
+        return self._serialized
+
+    def _finish(self):
+        if self._closed:
+            raise Exception("Can't change closed PendingUpdates")
+        self._closed = True
+        self._serialized_states = _serialize_mac_states(self._changes)
+        self._serialized = [s.to_message() for s in self._serialized_states]
+
+    def received_mac_state(self, state: capport.comm.message.MacState):
+        if self._closed:
+            raise Exception("Can't change closed PendingUpdates")
+        (addr, new_entry) = MacEntry.parse_state(state)
+        old_entry = self._database._macs.get(addr)
+        if not old_entry:
+            # only redistribute if not outdated
+            if not new_entry.outdated():
+                self._database._macs[addr] = new_entry
+                self._changes[addr] = new_entry
+        elif old_entry.merge(new_entry):
+            if old_entry.outdated():
+                # remove local entry, but still redistribute
+                self._database._macs.pop(addr)
+            self._changes[addr] = old_entry
+
+    def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float=0.8):
+        if self._closed:
+            raise Exception("Can't change closed PendingUpdates")
         now = cptypes.Timestamp.now()
         allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
         new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
-        entry = self._macs.get(mac)
+        entry = self._database._macs.get(mac)
         if not entry:
-            self._macs[mac] = new_entry
-            pending_updates.macs[mac] = new_entry
+            self._database._macs[mac] = new_entry
+            self._changes[mac] = new_entry
         elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
             # too much time left on clock, not renewing session
             return
         elif entry.merge(new_entry):
-            pending_updates.macs[mac] = entry
+            self._changes[mac] = entry
         elif not entry.allowed_remaining() > 0:
             # entry should have been updated - can only fail due to `now < entry.last_change`
             # i.e. out of sync clocks
             wait = entry.last_change.epoch - now.epoch
             raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)

-    def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
+    def logout(self, mac: cptypes.MacAddress):
+        if self._closed:
+            raise Exception("Can't change closed PendingUpdates")
         now = cptypes.Timestamp.now()
         new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
-        entry = self._macs.get(mac)
+        entry = self._database._macs.get(mac)
         if entry:
             if entry.merge(new_entry):
-                pending_updates.macs[mac] = entry
+                self._changes[mac] = entry
             elif entry.allowed_remaining() > 0:
                 # still logged in. can only happen with `now <= entry.last_change`
                 # clocks not necessarily out of sync, but you can't logout in the same second you logged in
                 wait = entry.last_change.epoch - now.epoch + 1
                 raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
-
-
-@dataclasses.dataclass
-class PendingUpdates:
-    macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
-
-    def serialize(self) -> typing.List[capport.comm.message.Message]:
-        return _serialize_mac_states_as_messages(self.macs)
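For reference, the state file written by _run_statefile and read by _load_statefile is a flat sequence of chunks: a 4-byte big-endian length prefix (struct.pack('!I', ...)) followed by one deterministically serialized MacStates protobuf message; updates are appended, and a resync atomically rewrites the file via the capport.state.new-<pid> temporary. A minimal standalone sketch of that framing, using raw byte payloads in place of the protobuf type (file name and helper functions here are illustrative):

import struct
import typing

def write_chunk(f: typing.BinaryIO, payload: bytes) -> None:
    # 4-byte big-endian length prefix, then the serialized message
    f.write(struct.pack('!I', len(payload)) + payload)

def read_chunks(f: typing.BinaryIO) -> typing.Iterator[bytes]:
    while True:
        len_bytes = f.read(4)
        if not len_bytes:
            return                      # clean EOF
        if len(len_bytes) < 4:
            raise EOFError("truncated length prefix")
        (size,) = struct.unpack('!I', len_bytes)
        yield f.read(size)

# usage: each payload would be MacStates.SerializeToString(deterministic=True)
with open('capport.state.example', 'wb') as f:
    write_chunk(f, b'example-payload')
with open('capport.state.example', 'rb') as f:
    print(list(read_chunks(f)))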